Skip to content

Commit

Permalink
Merge pull request #157 from carterbox/probe-options
Browse files Browse the repository at this point in the history
API: Remove leading dimensions from arrays in ptycho.reconstruct
  • Loading branch information
carterbox committed Aug 9, 2021
2 parents 34843e3 + eb63a1b commit c68700a
Show file tree
Hide file tree
Showing 14 changed files with 4,996 additions and 1,738 deletions.
6,404 changes: 4,811 additions & 1,593 deletions docs/source/examples/ptycho.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions src/tike/communicators/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def MPIio(self, scan, *args):

# Generate the mask
mask = np.logical_and(
edges[self.rank] < scan[0, :, 0],
scan[0, :, 0] <= edges[self.rank + 1])
edges[self.rank] < scan[:, 0],
scan[:, 0] <= edges[self.rank + 1])

scan = scan[:, mask]
split_args = [arg[:, mask] for arg in args]
scan = scan[mask]
split_args = [arg[mask] for arg in args]

return (scan, *split_args)
6 changes: 3 additions & 3 deletions src/tike/operators/cupy/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def fwd(
assert positions.shape[:-2] == patches.shape[:-3], (positions.shape,
patches.shape)
assert positions.shape[-2] * nrepeat == patches.shape[-3]
assert positions.shape[-1] == 2
assert positions.shape[-1] == 2, positions.shape
assert images.dtype == 'complex64'
assert patches.dtype == 'complex64'
assert positions.dtype == 'float32'
nimage = np.prod(images.shape[:-2])
nimage = int(np.prod(images.shape[:-2]))
grids = (
positions.shape[-2],
nimage,
Expand Down Expand Up @@ -118,7 +118,7 @@ def adj(
assert images.dtype == 'complex64'
assert patches.dtype == 'complex64'
assert positions.dtype == 'float32'
nimage = np.prod(images.shape[:-2])
nimage = int(np.prod(images.shape[:-2]))
grids = (
positions.shape[-2],
nimage,
Expand Down
8 changes: 4 additions & 4 deletions src/tike/operators/cupy/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class Propagation(CachedFFT, Operator):
farplane: (..., detector_shape, detector_shape) complex64
The wavefronts hitting the detector respectively.
Shape for cost functions and gradients is
(ntheta, nscan, 1, 1, detector_shape, detector_shape).
data, intensity : (ntheta, nscan, detector_shape, detector_shape) complex64
(nscan, 1, 1, detector_shape, detector_shape).
data, intensity : (nscan, detector_shape, detector_shape) complex64
data is the square of the absolute value of `farplane`. `data` is the
intensity of the `farplane`.
Expand Down Expand Up @@ -86,12 +86,12 @@ def _gaussian_cost(self, data, intensity):
def _gaussian_grad(self, data, farplane, intensity, overwrite=False):
return farplane * (
1 - np.sqrt(data) / (np.sqrt(intensity) + 1e-32)
)[:, :, np.newaxis, np.newaxis] # yapf:disable
)[..., np.newaxis, np.newaxis, :, :] # yapf:disable

def _poisson_cost(self, data, intensity):
return np.mean(intensity - data * np.log(intensity + 1e-32))

def _poisson_grad(self, data, farplane, intensity, overwrite=False):
return farplane * (
1 - data / (intensity + 1e-32)
)[:, :, np.newaxis, np.newaxis] # yapf: disable
)[..., np.newaxis, np.newaxis, :, :] # yapf: disable
4 changes: 2 additions & 2 deletions src/tike/operators/cupy/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _compute_intensity(self, data, psi, scan, probe):
)
return self.xp.sum(
(farplane * farplane.conj()).real,
axis=(2, 3),
axis=tuple(range(1, farplane.ndim - 2)),
), farplane

def cost(self, data, psi, scan, probe) -> float:
Expand Down Expand Up @@ -174,6 +174,6 @@ def grad_probe(self, data, psi, scan, probe, mode=None):
scan=scan,
overwrite=True,
),
axis=1,
axis=0,
keepdims=True,
)
4 changes: 2 additions & 2 deletions src/tike/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def batch_indicies(n, m=1, use_random=True):

def get_batch(x, b, n):
"""Returns x[:, b[n]]; for use with map()."""
return x[:, b[n]]
return x[b[n]]


def put_batch(y, x, b, n):
"""Assigns y into x[:, b[n]]; for use with map()."""
x[:, b[n]] = y
x[b[n]] = y


def adagrad(g, v=None, eps=1e-6):
Expand Down
4 changes: 3 additions & 1 deletion src/tike/ptycho/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@
"""
from .ptycho import *
from .position import check_allowed_positions
from .object import ObjectOptions
from .position import check_allowed_positions, PositionOptions
from .probe import ProbeOptions
39 changes: 37 additions & 2 deletions src/tike/ptycho/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,29 @@

import cupy as cp
import cupyx.scipy.ndimage
import numpy as np

logger = logging.getLogger(__name__)


# TODO: Use dataclass decorator when python 3.6 reaches EOL
class ObjectOptions:
"""Manage data and setting related to object correction."""
"""Manage data and setting related to object correction.
Attributes
----------
positivity_constraint : float [0, 1]
This value is passed to the tike.ptycho.object.positivity_constraint
function.
smoothness_constraint : float [0, 1/8]
This value is passed to the tike.ptycho.object.smoothness_constraint
function.
"""

def __init__(self, positivity_constraint=0, smoothness_constraint=0):
self.positivity_constraint = positivity_constraint
self.smoothness_constraint = positivity_constraint
self.smoothness_constraint = smoothness_constraint


def positivity_constraint(x, r):
Expand Down Expand Up @@ -67,3 +79,26 @@ def smoothness_constraint(x, a):
else:
raise ValueError(
f"Smoothness constraint must be in range [0, 1/8) not {a}.")


def get_padded_object(scan, probe):
"""Return a ones-initialized object and shifted scan positions.
An complex object array is initialized with shape such that the area
covered by the probe is padded on each edge by a full probe width. The scan
positions are shifted to be centered in this newly initialized object
array.
"""
# Shift scan positions to zeros
scan[..., 0] -= np.min(scan[..., 0])
scan[..., 1] -= np.min(scan[..., 1])

# Add padding to scan positions of field-of-view / 8
span = np.max(scan[..., 0]), np.max(scan[..., 1])
scan[..., 0] += probe.shape[-2]
scan[..., 1] += probe.shape[-1]

height = 3 * probe.shape[-2] + int(span[0])
width = 3 * probe.shape[-1] + int(span[1])

return np.ones((height, width), dtype='complex64'), scan
30 changes: 5 additions & 25 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,34 +124,13 @@ def check_allowed_positions(scan, psi, probe_shape):
)
if np.any(less_than_one) or np.any(greater_than_psi):
x = np.logical_or(less_than_one, greater_than_psi)
raise ValueError("These scan positions exist outside field of view:\n"
raise ValueError("Scan positions must be positive valued "
"and fit within the field of view "
"with at least a 1 pixel buffer around the edge. "
"These scan positions exist outside field of view:\n"
f"{scan[np.logical_or(x[..., 0], x[..., 1])]}")


def get_padded_object(scan, probe):
"""Return a ones-initialized object and shifted scan positions.
An complex object array is initialized with shape such that the area
covered by the probe is padded on each edge by a full probe width. The scan
positions are shifted to be centered in this newly initialized object
array.
"""
# Shift scan positions to zeros
scan[..., 0] -= np.min(scan[..., 0])
scan[..., 1] -= np.min(scan[..., 1])

# Add padding to scan positions of field-of-view / 8
span = np.max(scan[..., 0]), np.max(scan[..., 1])
scan[..., 0] += probe.shape[-2]
scan[..., 1] += probe.shape[-1]

ntheta = probe.shape[0]
height = 3 * probe.shape[-2] + int(span[0])
width = 3 * probe.shape[-1] + int(span[1])

return np.ones((ntheta, height, width), dtype='complex64'), scan


def update_positions_pd(operator, data, psi, probe, scan,
dx=-1, step=0.05): # yapf: disable
"""Update scan positions using the gradient of intensity method.
Expand Down Expand Up @@ -227,6 +206,7 @@ def update_positions_pd(operator, data, psi, probe, scan,

from cupy.fft.config import get_plan_cache


def _image_grad(x):
"""Return the gradient of the x for each of the last two dimesions."""
# FIXME: Use different gradient approximation that does not use FFT. Because
Expand Down
31 changes: 27 additions & 4 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,29 @@
logger = logging.getLogger(__name__)


class ProbeOptions:
"""Manage data and setting related to probe correction.
Attributes
----------
orthogonality_constraint : bool
Forces probes to be orthogonal each iteration.
num_eigen_probes : int
The number of eigen probes/components.
"""

def __init__(self, num_eigen_probes=0, orthogonality_constraint=True):
self.orthogonality_constraint = orthogonality_constraint
self._weights = None
self._eigen_probes = None
if num_eigen_probes > 0:
pass

@property
def num_eigen_probes(self):
return 0 if self._weights is None else self._weights.shape[-2]


def get_varying_probe(shared_probe, eigen_probe=None, weights=None):
"""Construct the varying probes.
Expand Down Expand Up @@ -149,9 +172,9 @@ def update_eigen_probe(comm, R, eigen_probe, weights, patches, diff, β=0.1):
def _get_update(R, eigen_probe, weights):
# (..., POSI, 1, 1, 1, 1) to match other arrays
weights = weights[..., None, None, None, None]
norm_weights = np.linalg.norm(weights[0], axis=-5, keepdims=True)**2
norm_weights = np.linalg.norm(weights, axis=-5, keepdims=True)**2

if np.all(norm_weights[0] == 0):
if np.all(norm_weights == 0):
raise ValueError('eigen_probe weights cannot all be zero?')

# FIXME: What happens when weights is zero!?
Expand Down Expand Up @@ -275,7 +298,7 @@ def add_modes_random_phase(probe, nmodes):
Parameters
----------
probe : (:, :, :, M, :, :) array
probe : (..., M, :, :) array
A probe with M > 0 incoherent modes.
nmodes : int
The number of desired modes.
Expand Down Expand Up @@ -333,7 +356,7 @@ def init_varying_probe(scan, shared_probe, N):
shared_probe.shape[-3],
).astype('float32')
weights -= np.mean(weights, axis=-3, keepdims=True)
weights[0, :] = 1.0 # The weight of the first eigen probe is non-zero
weights[..., 0, :] = 1.0 # The weight of the first eigen probe is non-zero

return eigen_probe, weights

Expand Down

0 comments on commit c68700a

Please sign in to comment.