Skip to content

Commit

Permalink
Merge pull request #189 from carterbox/position
Browse files Browse the repository at this point in the history
Port position correction to RPIE
  • Loading branch information
carterbox committed Feb 15, 2022
2 parents 4b3fb15 + 9d4984c commit 54f1d6c
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 66 deletions.
18 changes: 18 additions & 0 deletions src/tike/communicators/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings

import cupy as cp
import numpy as np


class ThreadPool(ThreadPoolExecutor):
Expand Down Expand Up @@ -71,6 +72,10 @@ def _copy_to(self, x, worker: int) -> cp.array:
with cp.cuda.Device(worker):
return self.xp.asarray(x)

def _copy_host(self, x, worker: int) -> np.array:
with cp.cuda.Device(worker):
return self.xp.asnumpy(x)

def bcast(self, x: list, s=1) -> list:
"""Send each x to all device groups.
Expand Down Expand Up @@ -101,6 +106,19 @@ def gather(self, x: list, worker=None, axis=0) -> cp.array:
axis,
)

def gather_host(self, x: list, axis=0) -> np.array:
"""Concatenate x on host along the given axis."""
if self.num_workers == 1:
return cp.asnumpy(x[0])

def f(x, worker):
return self._copy_host(x, worker)

return np.concatenate(
self.map(f, x, self.workers),
axis,
)

def all_gather(self, x: list, axis=0) -> list:
"""Concatenate x on all workers along the given axis."""

Expand Down
13 changes: 5 additions & 8 deletions src/tike/ptycho/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,17 @@ def get_padded_object(scan, probe):
"""Return a ones-initialized object and shifted scan positions.
An complex object array is initialized with shape such that the area
covered by the probe is padded on each edge by a full probe width. The scan
covered by the probe is padded on each edge by a half probe width. The scan
positions are shifted to be centered in this newly initialized object
array.
"""
pad = probe.shape[-1] // 2
# Shift scan positions to zeros
scan[..., 0] -= np.min(scan[..., 0])
scan[..., 1] -= np.min(scan[..., 1])
scan = scan - np.min(scan, axis=-2) + pad

# Add padding to scan positions of field-of-view / 8
span = np.max(scan[..., 0]), np.max(scan[..., 1])
scan[..., 0] += probe.shape[-2]
scan[..., 1] += probe.shape[-1]

height = 3 * probe.shape[-2] + int(span[0])
width = 3 * probe.shape[-1] + int(span[1])
height = probe.shape[-1] + int(span[0]) + pad
width = probe.shape[-1] + int(span[1]) + pad

return np.ones((height, width), dtype='complex64'), scan
29 changes: 14 additions & 15 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging

import cupy as cp
import cupyx.scipy.fft
import numpy as np

import tike.linalg
Expand All @@ -15,10 +16,7 @@
class PositionOptions:
"""Manage data and settings related to position correction."""

num_positions: int
"""The number of scanning positions."""

initial_scan: np.array = None
initial_scan: np.array
"""The original scan positions before they were updated using position
correction."""

Expand All @@ -37,12 +35,15 @@ class PositionOptions:

def __post_init__(self):
if self.use_adaptive_moment:
self._momentum = np.zeros((self.num_positions, 4), dtype='float32')
self._momentum = np.zeros(
(*self.initial_scan.shape[:-1], 4),
dtype='float32',
)

def split(self, indices):
"""Split the PositionOption meta-data along indices."""
new = PositionOptions(
0,
self.initial_scan[..., indices, :],
use_adaptive_moment=self.use_adaptive_moment,
vdecay=self.vdecay,
mdecay=self.mdecay,
Expand All @@ -54,18 +55,21 @@ def split(self, indices):

def join(self, other, indices):
"""Replace the PositionOption meta-data with other data."""
self.initial_scan[..., indices, :] = other.initial_scan
if self.use_adaptive_moment:
self._momentum[..., indices, :] = other._momentum
return self

def copy_to_device(self):
"""Copy to the current GPU memory."""
self.initial_scan = cp.asarray(self.initial_scan)
if self.use_adaptive_moment:
self._momentum = cp.asarray(self._momentum)
return self

def copy_to_host(self):
"""Copy to the host CPU memory."""
self.initial_scan = cp.asnumpy(self.initial_scan)
if self.use_adaptive_moment:
self._momentum = cp.asnumpy(self._momentum)
return self
Expand Down Expand Up @@ -200,18 +204,13 @@ def update_positions_pd(operator, data, psi, probe, scan,
return scan, cost


from cupy.fft.config import get_plan_cache


def _image_grad(x):
"""Return the gradient of the x for each of the last two dimesions."""
# FIXME: Use different gradient approximation that does not use FFT. Because
# FFT caches are per-thread and per-device, using FFT is inefficient.
ramp = 2j * cp.pi * cp.linspace(-0.5, 0.5, x.shape[-1], dtype='float32')
cache = get_plan_cache()
cache.set_size(0)
grad_x = cp.fft.ifft2(ramp * cp.fft.fft2(x))
grad_y = cp.fft.ifft2(ramp[:, None] * cp.fft.fft2(x))
grad_x = cupyx.scipy.fft.ifft2(ramp[:, None] * cupyx.scipy.fft.fft2(x))
grad_y = cupyx.scipy.fft.ifft2(ramp * cupyx.scipy.fft.fft2(x))
return grad_x, grad_y


Expand Down Expand Up @@ -290,7 +289,7 @@ def affine_position_regularization(
images=cp.zeros(psi.shape, dtype='complex64'),
positions=updated,
)
total_illumination = cp.fft.fft2(total_illumination)
total_illumination = cupyx.scipy.fft.fft2(total_illumination)
total_illumination *= _gaussian_frequency(
sigma=sigma,
size=total_illumination.shape[-1],
Expand All @@ -299,7 +298,7 @@ def affine_position_regularization(
sigma=sigma,
size=total_illumination.shape[-2],
)[..., None]
total_illumination = cp.fft.ifft2(total_illumination)
total_illumination = cupyx.scipy.fft.ifft2(total_illumination)
illum_proj = op.diffraction.patch.fwd(
images=total_illumination,
positions=updated,
Expand Down
13 changes: 1 addition & 12 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def reconstruct(
"across processes.")
else:
mpi = None
(psi, scan) = get_padded_object(scan, probe) if psi is None else (psi, scan)
check_allowed_positions(scan, psi, probe.shape)
with cp.cuda.Device(num_gpu[0] if isinstance(num_gpu, tuple) else None):
operator = Ptycho(
Expand Down Expand Up @@ -334,7 +333,6 @@ def _setup(
scan,
data,
eigen_weights,
initial_scan,
) = split_by_scan_grid(
comm.pool,
(
Expand All @@ -344,7 +342,6 @@ def _setup(
scan,
data,
eigen_weights,
None if position_options is None else position_options.initial_scan,
)
result = dict(
psi=comm.pool.bcast([psi.astype('complex64')]),
Expand All @@ -370,10 +367,6 @@ def _setup(
PositionOptions.copy_to_device,
(position_options.split(x) for x in comm.order),
)
if initial_scan is None:
position_options.initial_scan = comm.pool.map(cp.copy, scan)
else:
position_options.initial_scan = initial_scan

# Unique batch for each device
batches = comm.pool.map(
Expand Down Expand Up @@ -466,10 +459,6 @@ def _teardown(
comm.order,
):
position_options.join(x, o)
position_options.initial_scan = comm.pool.gather(
position_options.initial_scan,
axis=-2,
)[reorder].get()

return dict(
algorithm_options=algorithm_options,
Expand All @@ -486,7 +475,7 @@ def _teardown(
probe_options=result['probe_options'].copy_to_host()
if probe_options is not None else None,
psi=result['psi'][0].get(),
scan=comm.pool.gather(scan, axis=-2)[reorder].get(),
scan=comm.pool.gather_host(scan, axis=-2)[reorder],
)


Expand Down

0 comments on commit 54f1d6c

Please sign in to comment.