From b9058b9100a7580414c1df374d4226a9541a1354 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Fri, 21 Apr 2023 19:05:52 +0200 Subject: [PATCH] First commit. --- lensless/apgd_v2.py | 278 +++++++++++++++++++++++++++++++++++++++++ recon_requirements.txt | 3 +- 2 files changed, 280 insertions(+), 1 deletion(-) create mode 100644 lensless/apgd_v2.py diff --git a/lensless/apgd_v2.py b/lensless/apgd_v2.py new file mode 100644 index 00000000..715826d8 --- /dev/null +++ b/lensless/apgd_v2.py @@ -0,0 +1,278 @@ +from lensless.recon import ReconstructionAlgorithm +import inspect +import numpy as np +from pycsou.func.loss import SquaredL2Loss +from pycsou.func.penalty import NonNegativeOrthant, SquaredL2Norm, L1Norm +from pycsou.opt.proxalgs import APGD as APGD_pyc +from copy import deepcopy +from pycsou.core.linop import LinearOperator +from typing import Union, Optional +from numbers import Number +from scipy import fft +from scipy.fftpack import next_fast_len + + +import pycsou.abc as pyca +import pycsou.operator.func as func +import pycsou.opt.solver as solver +import pycsou.opt.stop as stop +import pycsou.runtime as pycrt +import pycsou.util as pycu +import pycsou.util.ptype as pyct + + +class APGDPriors: + """ + Priors (compatible with Pycsou) for APGD. + + See Pycsou documentation for available penalties: + https://matthieumeo.github.io/pycsou/html/api/functionals/pycsou.func.penalty.html?highlight=penalty#module-pycsou.func.penalty + """ + + L2 = "l2" + NONNEG = "nonneg" + L1 = "l1" + + @staticmethod + def all_values(): + vals = [] + for i in inspect.getmembers(APGDPriors): + # remove private and protected functions, and this function + if not i[0].startswith("_") and not callable(i[1]): + vals.append(i[1]) + return vals + + +class RealFFTConvolve2D(pyca.LinOp): + def __init__(self, filter: pyct.NDArray, dtype: Optional[type] = None): + """ + Linear operator that performs convolution in Fourier domain, and assumes + real-valued signals. + + Parameters + ---------- + filter :py:class:`~numpy.ndarray` + 2D filter to use. Must be of shape (height, width, channels) even if + only one channel. + dtype : float32 or float64 + Data type to use for optimization. + """ + + assert len(filter.shape) == 3 + self._filter_shape = np.array(filter.shape) + self._n_channels = filter.shape[2] + + # cropping / padding indices + self._padded_shape = 2 * self._filter_shape[:2] - 1 + self._padded_shape = np.array([next_fast_len(i) for i in self._padded_shape]) + self._padded_shape = np.r_[self._padded_shape, [self._n_channels]] + self._start_idx = (self._padded_shape[:2] - self._filter_shape[:2]) // 2 + self._end_idx = self._start_idx + self._filter_shape[:2] + + # precompute filter in frequency domain + self._H = fft.rfft2(self._pad(filter), axes=(0, 1)) + self._Hadj = np.conj(self._H) + self._padded_data = np.zeros(self._padded_shape).astype(dtype) + + shape = (int(np.prod(self._filter_shape)), int(np.prod(self._filter_shape))) + super(RealFFTConvolve2D, self).__init__(shape=shape) + + def _crop(self, x): + return x[self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]] + + def _pad(self, v): + vpad = np.zeros(self._padded_shape).astype(v.dtype) + vpad[self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]] = v + return vpad + + @pycrt.enforce_precision(i="arr") + @pycu.vectorize(i="arr") + def apply(self, x: pyct.NDArray) -> pyct.NDArray: + self._padded_data[ + self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1] + ] = np.reshape(x, self._filter_shape) + y = self._crop( + fft.ifftshift( + fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._H, axes=(0, 1)), + axes=(0, 1), + ) + ) + return y.ravel() + + @pycrt.enforce_precision(i="arr") + @pycu.vectorize(i="arr") + def adjoint(self, y: pyct.NDArray) -> pyct.NDArray: + self._padded_data[ + self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1] + ] = np.reshape(y, self._filter_shape) + x = self._crop( + fft.ifftshift( + fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._Hadj, axes=(0, 1)), + axes=(0, 1), + ) + ) + return x.ravel() + + +class APGD(ReconstructionAlgorithm): + def __init__( + self, + psf, + max_iter=500, + dtype=np.float32, + diff_penalty=None, + prox_penalty=APGDPriors.NONNEG, + acceleration="BT", + diff_lambda=0.001, + prox_lambda=0.001, + **kwargs + ): + """ + Wrapper for `Pycsou's APGD `__ + (accelerated proximal gradient descent) applied to lensless imaging. + + Parameters + ---------- + psf : :py:class:`~numpy.ndarray` + PSF that models forward propagation. + max_iter : int, optional + Maximal number of iterations. + dtype : float32 or float64 + Data type to use for optimization. + diff_penalty : None or str or :py:class:`~pycsou.core.functional.DifferentiableFunctional` + Differentiable functional to serve as prior / regularization term. + Default is None. See `Pycsou documentation `__ + for available penalties. + prox_penalty : None or str or :py:class:`~pycsou.core.functional.ProximableFunctional` + Proximal functional to serve as prior / regularization term. Default + is non-negative prior. See `Pycsou documentation `__ + for available penalties. + acceleration : [None, 'BT', 'CD'] + Which acceleration scheme should be used (None for no acceleration). + "BT" (Beck and Teboule) has convergence `O(1/k^2)`, while "CD" + (Chambolle and Dossal) has convergence `o(1/K^2)`. So "CD" should be + faster. but from our experience "BT" gives better results. + diff_lambda : float + Weight of differentiable penalty. + prox_lambda : float + Weight of proximal penalty. + """ + + # PSF and data are the same size / shape + self._original_shape = psf.shape + self._original_size = psf.size + + self._apgd = None + self._gen = None + + super(APGD, self).__init__(psf, dtype) + + self._max_iter = max_iter + + # Convolution operator + self._H = RealFFTConvolve2D(self._psf, dtype=dtype) + + # initialize solvers which will be created when data is set + if diff_penalty is not None: + if diff_penalty == APGDPriors.L2: + self._diff_penalty = diff_lambda * func.SquaredL2Norm(dim=self._H.shape[1]) + else: + assert hasattr(diff_penalty, "jacobianT") + self._diff_penalty = diff_lambda * diff_penalty(dim=self._H.shape[1]) + else: + self._diff_penalty = None + + if prox_penalty is not None: + if prox_penalty == APGDPriors.L1: + self._prox_penalty = prox_lambda * func.L1Norm(dim=self._H.shape[1]) + elif prox_penalty == APGDPriors.NONNEG: + self._prox_penalty = prox_lambda * func.PositiveOrthant(dim=self._H.shape[1]) + else: + try: + self._prox_penalty = prox_lambda * prox_penalty(dim=self._H.shape[1]) + except ValueError: + print("Unexpected prior.") + else: + self._prox_penalty = None + + self._acc = acceleration + + def set_data(self, data): + """ + For ``APGD``, we use data to initialize problem for Pycsou. + + Parameters + ---------- + data : :py:class:`~numpy.ndarray` + Lensless data on which to iterate to recover an estimate of the + scene. Should match provide PSF, i.e. shape and 2D (grayscale) or + 3D (RGB). + + """ + if not self._is_rgb: + assert len(data.shape) == 2 + data = data[:, :, np.newaxis] + assert len(self._psf_shape) == len(data.shape) + self._data = data + + """ Set up problem """ + # Cost function + loss = (1 / 2) * func.SquaredL2Norm(dim=self._H.shape[0]).asloss(self._data.ravel()) + F = loss * self._H + if self._diff_penalty is not None: + F += self._diff_penalty + + # if self._prox_penalty is not None: + # G = self._prox_penalty + # else: + # G = None + + self._apgd = solver.PGD( + f=F, + g=self._prox_penalty, + show_progress=False, + ) + + # self._apgd.fit( # BLOCK-ing mode for simplicity + # x0=np.zeros(F.shape[1]), + # # x0=rng.normal(size=F.dim), + # stop_crit=sc, + # track_objective=True, + # mode=pyca.solver.Mode.MANUAL if config.save_inter else pyca.solver.Mode.BLOCK, + # ) + + # self._apgd = APGD_pyc(dim=dim, F=F, G=G, acceleration=self._acc) + + # -- setup to print progress report + self._apgd.old_iterand = deepcopy(self._apgd.init_iterand) + self._apgd.update_diagnostics() + self._gen = self._apgd.iterates(n=self._max_iter) + + def reset(self): + self._image_est = np.zeros(self._original_size, dtype=self._dtype) + if self._apgd is not None: + self._apgd.reset() + + # -- setup to print progress report + self._apgd.old_iterand = deepcopy(self._apgd.init_iterand) + self._apgd.update_diagnostics() + self._gen = self._apgd.iterates(n=self._max_iter) + + def _progress(self): + """ + Pycsou has functionality for printing progress that we will make use of + here. + + """ + self._apgd.update_diagnostics() + self._apgd.old_iterand = deepcopy(self._apgd.iterand) + self._apgd.print_diagnostics() + + def _update(self): + next(self._gen) + self._image_est[:] = self._apgd.iterand["iterand"] + + def _form_image(self): + image = self._image_est.reshape(self._original_shape) + image[image < 0] = 0 + return image diff --git a/recon_requirements.txt b/recon_requirements.txt index 2d148c64..01f6a7c1 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -4,4 +4,5 @@ pylops==1.18.0 scikit-image==0.19.0rc0 hydra-core click>=8.0.1 -waveprop>=0.0.3 # for simulation \ No newline at end of file +waveprop>=0.0.3 # for simulation +pip install git+https://github.com/matthieumeo/pycsou.git@v2-dev \ No newline at end of file