In [None]:
import time as t
import warnings
from functools import partial

import matplotlib.pyplot as plt
import napari
import numpy as np
import scipy as sp
import scipy.signal
import skimage.io
from jupyter_compare_view import compare
from pycsou.abc import DiffFunc, DiffMap, LinOp, Map, ProxFunc
from pycsou.operator import SquaredL2Norm
from pycsou.operator.interop import from_sciop, from_source, from_torch
from pycsou.operator.interop.torch import *
from pycsou.runtime import Precision, Width, enforce_precision
from pycsou.util import get_array_module, to_NUMPY

warnings.filterwarnings("ignore")

# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# Matplotlib 3.8 removed the old names, so `plt.style.use("seaborn-darkgrid")`
# raises on current releases. Use whichever spelling this Matplotlib provides.
_DARKGRID_STYLE = next(
    (
        style_name
        for style_name in ("seaborn-v0_8-darkgrid", "seaborn-darkgrid")
        if style_name in plt.style.available
    ),
    "default",
)
plt.style.use(_DARKGRID_STYLE)
plt.rcParams.update(
    {
        "figure.figsize": [9, 6],
        "figure.dpi": 150,
        "axes.grid": False,
        "image.cmap": "viridis",
    }
)

# Seeded generator so any stochastic step in the notebook is reproducible.
rng = np.random.default_rng(seed=0)


def monochromatic(im, chan=0):
    """Return a copy of a channel-first image keeping only one channel.

    Every element outside ``im[chan]`` is set to 0; ``im`` itself is not
    modified. Works on any array type supported by ``get_array_module``.
    """
    xp = get_array_module(im)
    # True everywhere except on the slab selected by `chan`.
    zero_out = xp.full(im.shape, True)
    zero_out[chan] = False
    isolated = im.copy()
    isolated[zero_out] = 0
    return isolated


def imshow(im, rgb=False):
    """Display an image with matplotlib, with axes hidden.

    Parameters
    ----------
    im : array-like
        Image accepted by ``to_NUMPY``. A 2-D array is shown in grayscale; an
        array with more than two dimensions is treated as channel-first and
        its leading axis is moved last before plotting.
    rgb : bool, optional
        When True and ``im`` has more than two dimensions, draw a 2x2 grid:
        the composite image followed by channels 0, 1 and 2, each isolated
        with ``monochromatic``.
    """
    im = to_NUMPY(im)
    if im.ndim > 2 and rgb:
        # Panel 1 is the composite; panels 2-4 isolate one channel each.
        for position, chan in enumerate((None, 0, 1, 2), start=1):
            panel = im if chan is None else monochromatic(im, chan)
            plt.subplot(2, 2, position)
            plt.imshow(np.moveaxis(panel, 0, -1))
            # Hide axes on every panel: a single trailing plt.axis("off")
            # would only affect the last (current) subplot.
            plt.axis("off")
        return
    if im.ndim > 2:
        plt.imshow(np.moveaxis(im, 0, -1))
    else:
        plt.imshow(im, cmap="gray")
    plt.axis("off")


def imshow_compare(*images, **kwargs):
    """Show images side by side in an interactive ``compare`` slider.

    Each image is converted to NumPy, clipped to [0, 1] and, if it has more
    than two dimensions, moved from channel-first to channel-last layout.
    Extra keyword arguments are forwarded to ``compare``; the widget it
    returns is returned unchanged.
    """

    def _to_display(image):
        clipped = np.clip(to_NUMPY(image), 0, 1)
        if clipped.ndim > 2:
            return np.moveaxis(clipped, 0, -1)
        return clipped

    prepared = list(map(_to_display, images))
    return compare(
        *prepared,
        height=700,
        add_controls=True,
        display_format="jpg",
        **kwargs,
    )


# Warnings are silenced once, in the configuration statements at the top of this cell.

<p align="center">
<img src="https://matthieumeo.github.io/pycsou/html/_images/pycsou.png" alt="Pycsou logo" width="65%">
</p>

# A High Performance Computational Imaging Framework for Python

In [None]:
# Load data: one volume and one PSF per channel, downsampled and moved to the GPU.

import cupy as cp

from utils import downsample_volume, epfl_deconv_data

N_CHANNELS = 3  # channels loaded from the dataset (later displayed as RGB)
DOWNSAMPLE_RATE = 2


def load_channel(channel, rate=DOWNSAMPLE_RATE):
    """Load one channel of the EPFL data and preprocess it on the GPU.

    Returns ``(y_c, psf_c)`` as CuPy arrays: both are downsampled by ``rate``
    with ``downsample_volume``; ``y_c`` is min-max rescaled (min 0, max 1)
    and ``psf_c`` is normalised to unit sum.
    """
    y_c, psf_c = epfl_deconv_data(channel)
    y_c = cp.asarray(downsample_volume(y_c, rate))
    psf_c = cp.asarray(downsample_volume(psf_c, rate))

    # Same preprocessing as in Scico.
    y_c -= y_c.min()
    y_c /= y_c.max()
    psf_c /= psf_c.sum()
    return y_c, psf_c


channel_data = [load_channel(channel) for channel in range(N_CHANNELS)]
# Stack along a new leading axis -> (channel, *spatial).
y = cp.stack([y_c for y_c, _ in channel_data])
psf = cp.stack([psf_c for _, psf_c in channel_data])
del channel_data  # only the stacked copies are used by later cells

In [None]:
for label, arr in (("y", y), ("psf", psf)):
    print(f"{label}.shape={arr.shape!r}")

In [None]:
# Open one napari window per stack. `.T` reverses all axes, so the leading
# channel axis becomes the last one, as required by rgb=True.
for volume in (y, psf):
    viewer = napari.view_image(volume.get().T, rgb=True)

In [None]:
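# Scico zero-pads the data and builds a mask (commented code below) so that the blur model does not wrap around the volume edges. This is not reproduced here: the FFT-diagonalised operator defined next implies periodic (circular) boundaries.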

# padding = [[0, p] for p in snp.array(psf.shape) - 1]
# y_pad = snp.pad(y, padding)
# mask = snp.pad(snp.ones_like(y), padding)

## Compute pseudo-inverse solution

In [None]:
# Create operators.
# (Scico's mask operator, not used here: M = linop.Diagonal(mask))

import pycsou.runtime as pycrt
from pycsou.operator import block_diag
from pycsou.operator.linop.base import DiagonalOp
from pycsou.operator.linop.fft import FFT
from pycsou.util.complex import view_as_complex, view_as_real

# The operator is built from the PSF and then applied to `y`, so both stacks
# must share the same (channel, *spatial) layout.
assert y.shape == psf.shape, f"data {y.shape} and PSF {psf.shape} shapes differ"

arg_shape = psf[0].shape  # spatial shape of a single channel
ndim = psf[0].ndim
size = psf[0].size

# Real-input FFT over every spatial axis of one channel.
fft = FFT(arg_shape=arg_shape, axes=tuple(range(ndim)), real=True)
fft.lipschitz(tight=True)

# Independent FFT per channel (channel count read from the data, not hard-coded).
fft = block_diag([fft] * psf.shape[0])

# Magnitude of the PSF spectrum, one value per frequency bin.
# NOTE(review): abs() drops the spectral phase (including the linear phase of
# a PSF centred in the volume rather than at the origin), i.e. this models a
# zero-phase blur -- presumably intentional; confirm.
otf_mag = abs(view_as_complex(fft(psf.ravel())))
# DiagonalOp scales the *real view* of the spectrum element-wise, so the gain
# must occupy both the real and the imaginary slot of every bin. The previous
# `otf_mag + 0j` put 0 in the imaginary slots and silently discarded the
# imaginary part of the data spectrum, so the operator was not a convolution.
psf_fourier = DiagonalOp(view_as_real(otf_mag * (1 + 1j)))

# The 1/size factor presumably compensates for fft.T (the adjoint) being
# `size` times the inverse FFT.
convolve = (1 / size) * fft.T * psf_fourier * fft

In [None]:
# Damped pseudo-inverse of the blur, reshaped back to (channel, *spatial)
# and peak-normalised for display.
pinv_damp = 1e-2
y_pinv = convolve.pinv(
    y.ravel(),
    damp=pinv_damp,
    kwargs_init={"show_progress": False},
)
y_pinv = y_pinv.reshape(y.shape)
y_pinv /= y_pinv.max()

In [None]:
mid = y.shape[3] // 2  # middle index along the last axis
imshow_compare(y[..., mid], y_pinv[..., mid])

In [None]:
viewer = napari.Viewer()
for layer_name, volume in (("original", y), ("pinv", y_pinv)):
    viewer.add_image(volume.get(), name=layer_name)

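# Demo: Bayesian Image Deconvolution
The deconvolved volume is the solution of

$$\arg \min_{\mathbf{x}} \; \frac{1}{2} \|\mathbf{y} - \mathbf{F} \mathbf{x}
  \|_2^2 + \lambda \| \nabla \mathbf{x} \|_{2,1} +
  \iota_{\mathrm{+}}(\mathbf{x}) \;,$$

where $\mathbf{F}$ is the FFT-based convolution with the PSF, $\nabla$ is the per-channel spatial gradient, $\|\cdot\|_{2,1}$ takes the $\ell_2$ norm across gradient directions and sums it over voxels (isotropic total variation), $\lambda > 0$ weights the regularisation, and $\iota_{+}$ is the indicator of the non-negative orthant (it enforces $\mathbf{x} \ge 0$).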

In [None]:
from pycsou.operator import Gradient, L1Norm, L21Norm, PositiveOrthant

# Number of channels, read from the data instead of hard-coding 3.
n_channels = y.shape[0]

# Non-negativity constraint on the flattened estimate.
range_constraint = PositiveOrthant(dim=y.size)

# Squared-L2 data-fidelity term centred on the observed data.
sl2 = SquaredL2Norm(dim=y.size).asloss(y.ravel())
sl2.diff_lipschitz()

# Per-channel spatial gradient; diff_method="gd" with a per-axis `sigma`
# (presumably Gaussian-derivative smoothing -- see the pycsou Gradient docs).
grad = Gradient(arg_shape=arg_shape, diff_method="gd", sigma=[2.0, 2.0, 1.0], gpu=True)
grad = block_diag([grad] * n_channels)
grad.lipschitz(tight=False, tol=0.1)

# Mixed L2,1 norm over the stacked gradients, laid out as
# (channel, gradient direction, *spatial): the L2 norm is taken across the
# gradient directions (axis 1), i.e. an isotropic TV penalty.
l21 = L21Norm(arg_shape=(n_channels, ndim) + arg_shape, l2_axis=1)

# Smooth part of the objective: data fidelity composed with the blur.
loss = sl2 * convolve
loss.diff_lipschitz()

In [None]:
from pycsou.opt.solver import CV
from pycsou.opt.stop import MaxIter, RelError

# Stop when both the `x` and `z` iterates change by less than 1% (relative,
# L2) and at least 20 iterations have run -- or after 100 iterations regardless.
_rel_tol = 1e-2
converged = (
    RelError(eps=_rel_tol, var="x", f=None, norm=2, satisfy_all=True)
    & RelError(eps=_rel_tol, var="z", f=None, norm=2, satisfy_all=True)
    & MaxIter(20)
)
default_stop_crit = converged | MaxIter(100)

λ = 2e-6  # regularisation weight, borrowed from the Scico example

# Condat-Vu primal-dual splitting: smooth term `f`, constraint `g`, and the
# regulariser `h` applied through the linear operator `K`.
solver = CV(
    f=loss,
    g=range_constraint,
    h=λ * l21,
    K=grad,
    show_progress=True,
    verbosity=50,
)

# Run the solver in single precision, then peak-normalise for display.
with pycrt.Precision(pycrt.Width.SINGLE):
    x0 = cp.zeros(y.size)
    solver.fit(x0=x0, tuning_strategy=2, stop_crit=default_stop_crit)
    y_tv = solver.solution().reshape(y.shape)
    y_tv /= y_tv.max()

In [None]:
mid = y.shape[3] // 2  # middle index along the last axis
imshow_compare(y[..., mid], y_tv[..., mid])

In [None]:
# `.T` moves the leading channel axis last, as required by rgb=True.
viewer = napari.Viewer()
for layer_name, volume in (("original", y), ("TV", y_tv)):
    viewer.add_image(volume.get().T, name=layer_name, rgb=True)