# Image Reconstruction by Deconvolution, Denoising and Fusion

In [None]:
from PIL import Image
from jupyter_compare_view import compare
import numpy as np

# Load the aerial photograph used as ground truth; the last expression displays it.
image_path = '../data/EPFL_aereal.jpeg'
ground_truth_image = Image.open(image_path)
ground_truth_image

In [None]:
# Convert the image to a float32 array with values in [0, 1] and keep the
# bottom-right crop_rows x crop_cols patch. `.convert("RGB")` guarantees exactly
# three colour channels (a JPEG may also be greyscale or CMYK), which the
# 3-channel histogram and per-axis Stencil kernels further down rely on.
crop_rows, crop_cols = 500, 1000
image = (
    np.array(ground_truth_image.convert("RGB"))[-crop_rows:, -crop_cols:]
    .astype(np.float32) / 255.
)

dim_shape = image.shape  # (rows, cols, channels), channel-last
print(dim_shape)

In [None]:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(3, 1, sharex=True, sharey=True)
# One intensity histogram per colour channel of the channel-last image (R, G, B).
for ch, (ax, color) in enumerate(zip(axs, "rgb")):
    ax.hist(image[..., ch].ravel(), color=color)
    ax.set_ylabel(color.upper())
axs[-1].set_xlabel("intensity")

## Lens 


In [None]:
import scipy.signal as ss
import pyxu.operator as pxo

# Lens blur as a separable stencil: one 1-D kernel per axis. Both spatial axes
# use the same normalised Hamming window; the colour axis gets the trivial
# 1-tap kernel, so channels are not mixed.
wsize = 20
filt = ss.get_window('hamming', wsize)
psf = filt / filt.sum()  # normalised so the PSF sums to one

lens_op = pxo.Stencil(
    dim_shape=dim_shape,
    kernel=[psf, psf, np.array(1.)],
    center=[wsize // 2, wsize // 2, 0],
)
print(lens_op.lipschitz)

In [None]:
# Simulated lens measurement: the ground-truth image convolved with the PSF.
data_lens = lens_op(image) 

In [None]:
# Interactive comparison: blurred lens data vs. ground truth.
compare(data_lens, image, circle_fraction=0.1)

### Benchmark: Convolution with Pyxu vs. SciPy

In [None]:
import timeit

In [None]:
%%timeit 
# Time one application of the Pyxu Stencil blur.
lens_op(image)

In [None]:
from scipy.signal import convolve as conv_scipy

In [None]:
%%timeit
# Reference timing: the same separable blur done with SciPy, one direct 1-D
# pass per spatial axis (the colour axis is untouched). This cell only
# measures runtime; the outputs are not compared to the Pyxu result.
convolved_rows = conv_scipy(  # ... along axis 0 (vertical)
    image,
    psf[:, np.newaxis, np.newaxis],
    mode="same",
    method="direct",
)

convolved_cols = conv_scipy(  # ... then along axis 1 (horizontal)
    convolved_rows,
    psf[np.newaxis, :, np.newaxis],
    mode="same",
    method="direct",
)

## Pinhole photography

In [None]:
# Pinhole camera model: smooth vignetting taper followed by partial masking.
# All diagonals are built in the image's dtype and channel-last layout, so the
# operator does not silently upcast the float32 image to float64.

# Tapering: separable 2-D Tukey window (rows x columns), repeated over channels.
tukey_window = np.outer(
    ss.get_window('tukey', image.shape[0]),
    ss.get_window('tukey', image.shape[1]),
).astype(image.dtype)
taper = np.repeat(tukey_window[..., np.newaxis], image.shape[2], axis=2)
taper_op = pxo.DiagonalOp(vec=taper, dim_shape=dim_shape)
taper_op.lipschitz = taper_op.estimate_lipschitz()

# Partial masking: two rectangles only let through 10% / 30% of the light.
mask = np.ones(image.shape, dtype=image.dtype)
mask[100:250, 200:300] = 0.1
mask[250:300, 700:750] = 0.3
mask_op = pxo.DiagonalOp(vec=mask, dim_shape=dim_shape)
mask_op.lipschitz = mask_op.estimate_lipschitz()

pinhole_op = mask_op * taper_op
pinhole_op.lipschitz

In [None]:
data_pinhole = pinhole_op(image)  # low exposure
# Seeded generator: the simulated noise, and every result derived from it, is reproducible.
rng = np.random.default_rng(seed=0)
# Shot noise approximated as Gaussian with variance data/16. The inner clip
# guards sqrt against tiny negative round-off; the outer clip keeps pixels in [0, 1].
data_pinhole = np.clip(
    rng.normal(loc=data_pinhole, scale=np.sqrt(np.clip(data_pinhole, 0, None)) / 4),
    a_min=0,
    a_max=1,
)

In [None]:
# Interactive comparison: noisy, tapered and masked pinhole data vs. ground truth.
compare(data_pinhole, image, circle_fraction=0.1)

In [None]:
# Sanity check: shape of the simulated pinhole measurement.
data_pinhole.shape

In [None]:
# Visual reference: left half from the pinhole camera, right half from the lens,
# split at the image's mid-column rather than at a hard-coded pixel index.
split = image.shape[1] // 2
data_merge = np.concatenate([data_pinhole[:, :split], data_lens[:, split:]], axis=1)
compare(data_merge, image, circle_fraction=0.1)

## Pseudo-inverse Solution

In [None]:
import pyxu.operator.blocks as pxb

# Joint forward model: both cameras, weighted 0.7 (lens) and 0.3 (pinhole).
# The measurements must be stacked with the same weights and order.
sensing_op = pxb.stack([0.7*lens_op,0.3*pinhole_op])

In [None]:
# Display the stacked sensing operator.
sensing_op

In [None]:
# Stack the measurements with the same weights (0.7 lens, 0.3 pinhole) and order
# as the operators in `sensing_op`, then display the resulting shape.
data = np.stack([0.7*data_lens, 0.3*data_pinhole])
data.shape

In [None]:
# Crude first reconstruction: back-project the stacked data with the adjoint.
adj_recon = sensing_op.adjoint(data)

In [None]:
# Interactive comparison: adjoint reconstruction vs. the merged raw data.
compare(adj_recon, data_merge, circle_fraction=0.1)

In [None]:
import pyxu.opt.stop as pxst

# Stop when the relative change of the iterate x falls below 1e-3,
# or after at most 500 iterations.
stop_crit = (
    pxst.RelError(eps=1e-3, var="x", f=None, norm=2, satisfy_all=True)
    | pxst.MaxIter(500)
)

# Damped least-squares pseudo-inverse of the joint sensing operator.
pinv_solution = sensing_op.pinv(
    data,
    damp=0.01,
    kwargs_init={"verbosity": 500},
    kwargs_fit={"stop_crit": stop_crit},
)

In [None]:
# Pseudo-inverse reconstruction (clipped to [0, 1] for display) vs. the merged raw data.
compare(pinv_solution.clip(0, 1), data_merge, circle_fraction=0.1)

## Bayesian Inversion

### Likelihood

In [None]:
# Data-fidelity weights: theta[0] for the lens term, theta[1] for the pinhole term.
theta = 10, 0.006

# Lens term: theta[0] * ||lens_op(x) - data_lens||_2^2.
# argshift(-y) evaluates the norm at (argument - y); the trailing `* lens_op`
# composes it with the forward operator.
loss_lens = theta[0] * pxo.SquaredL2Norm(dim_shape=dim_shape).argshift(-data_lens) * lens_op
# Pinhole term: theta[1] * Moreau envelope (parameter 0.01) of the L1 norm of
# pinhole_op(x) - data_pinhole, i.e. a smooth surrogate of an L1 data misfit.
loss_pinhole = theta[1] * pxo.L1Norm(dim_shape=dim_shape).moreau_envelope(0.01).argshift(-data_pinhole) * pinhole_op

# Total likelihood; used as the smooth term `f` by the solvers below.
loss = loss_lens + loss_pinhole

### Prior

In [None]:
# Define multi-channel TV prior
# Positivity constraint on the reconstruction (x >= 0).
pos_constraint = pxo.PositiveOrthant(dim_shape=dim_shape)

# Multi-channel isotropic TV prior: finite-difference gradient along the two
# spatial axes (0 and 1), accuracy order 4, 'constant' boundary mode ...
grad = pxo.Gradient(
    dim_shape=dim_shape,
    directions=(0, 1),
    accuracy=4,
    mode='constant',
)
lambda_ = .005
# ... followed by a weighted L2,1 norm. L2 is taken over axes 0 and 3 of the
# gradient output: presumably the derivative-direction axis and the colour
# axis (confirm against grad.codim_shape).
l21_norm = lambda_ * pxo.L21Norm(dim_shape=grad.codim_shape, l2_axis=(0, 3))

In [None]:
# Display the Lipschitz constant of the likelihood's gradient.
loss.diff_lipschitz

In [None]:
import pyxu.opt.solver as pxsol

# Stopping criterion
# Stop once both the primal (x) and dual (z) iterates have stabilised
# (relative change < 1e-3) and at least 20 iterations have run, or
# unconditionally after 1000 iterations.
default_stop_crit = (
    pxst.RelError(eps=1e-3, var="x", f=None, norm=2, satisfy_all=True)
    & pxst.RelError(eps=1e-3, var="z", f=None, norm=2, satisfy_all=True)
    & pxst.MaxIter(20)
) | pxst.MaxIter(1000)

# Condat-Vu primal-dual splitting. Terms: smooth `loss`, proximable
# `pos_constraint`, and `l21_norm` composed with the linear operator `grad`.
solver = pxsol.CondatVu(
    f=loss,
    g=pos_constraint,
    h=l21_norm,
    K=grad,
    verbosity=100,
)

# Warm-start from the pseudo-inverse solution.
solver.fit(x0=pinv_solution, tuning_strategy=2, stop_crit=default_stop_crit)
isotv_solution = solver.solution()

In [None]:
# Isotropic-TV reconstruction (clipped for display) vs. the merged raw data.
compare(isotv_solution.clip(0, 1), data_merge, circle_fraction=0.1)

In [None]:
# Isotropic-TV reconstruction vs. the pseudo-inverse solution.
compare(isotv_solution.clip(0, 1), pinv_solution.clip(0, 1), circle_fraction=0.1)

### Plug-and-Play

In [None]:
from pyxu.abc.operator import ProxFunc
from pyxu.info.deps import NDArrayInfo

class MedianFilter(ProxFunc):
    """Median filter exposed as a (pseudo-)proximal operator for plug-and-play solvers.

    Only ``prox`` is meaningful. It median-filters its input, on CPU via SciPy
    or on GPU via CuPy. There is no associated functional value, so ``apply``
    raises.

    Parameters
    ----------
    dim_shape : tuple of int
        Shape of a single input image.
    filter_size : int or tuple of int
        Median window size: a scalar (same size along every axis of
        ``dim_shape``) or one entry per axis of ``dim_shape``. Default ``3``.
    """

    def __init__(self, dim_shape, filter_size=3):
        super().__init__(dim_shape=dim_shape, codim_shape=(1,))
        self._filter_size = filter_size

    def apply(self, arr):
        # Raise rather than returning the `NotImplemented` sentinel, which would
        # propagate silently as a bogus "value".
        raise NotImplementedError("MedianFilter has no functional value; only prox() is defined.")

    def prox(self, arr, tau=None):
        """Return ``arr`` median-filtered, with the same shape.

        ``tau`` is accepted for API compatibility and ignored. Leading stacking
        dimensions beyond ``dim_shape`` are filtered independently (window
        size 1 along them).

        Raises
        ------
        NotImplementedError
            If ``arr`` is neither a NumPy nor a CuPy array.
        """
        ndi = NDArrayInfo.from_obj(arr)
        if ndi == NDArrayInfo.NUMPY:
            import scipy.ndimage as ndimage
        elif ndi == NDArrayInfo.CUPY:
            import cupyx.scipy.ndimage as ndimage
        else:
            # Previously fell through to an UnboundLocalError.
            raise NotImplementedError(f"MedianFilter.prox does not support {ndi} arrays.")

        size = self._filter_size
        if np.isscalar(size):
            size = (size,) * len(self.dim_shape)
        # Never mix values across stacked inputs: window size 1 on leading axes.
        size = (1,) * (arr.ndim - len(self.dim_shape)) + tuple(size)
        return ndimage.median_filter(arr, size=size)

# 3x3 spatial window; size 1 along the colour axis, so channels are not mixed.
median_op = MedianFilter(dim_shape=dim_shape, filter_size=(3, 3, 1))

In [None]:
# Plug-and-play accelerated PGD: the median filter stands in for the prior's proximal step.
solver = pxsol.PGD(f=loss, g=median_op, verbosity=100)
solver.fit(
    x0=pinv_solution,
    acceleration=True,
    stop_crit=pxst.MaxIter(250),
)
pnp_recons = solver.solution()

In [None]:
# Median-filter PnP reconstruction vs. the isotropic-TV reconstruction.
compare(pnp_recons.clip(0, 1), isotv_solution.clip(0, 1), circle_fraction=0.1)

### Interoperability with PyTorch

In [None]:
import torch  # needed below for torch.load (was only imported in a later cell)
from torch import nn


class Autoencoder(nn.Module):
    """Small convolutional autoencoder mapping 3-channel images to 3-channel images.

    The shape comments assume a 32 x 32 input with batch size ``b``. The
    network has only convolutional layers, so other spatial sizes also pass
    through it.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),  # b, 16, 16, 16
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # b, 32, 8, 8
            nn.ReLU(True),
            nn.Conv2d(32, 64, 7)  # b, 64, 2, 2
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),  # b, 32, 8, 8
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # b, 16, 16, 16
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),  # b, 3, 32, 32
            nn.Sigmoid()  # Compress the output to a pixel range of [0, 1]
        )

    def forward(self, x):
        """Encode then decode ``x`` (channel-first, (batch, 3, H, W))."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = Autoencoder()
# NOTE: torch.load unpickles the checkpoint - only load trusted files.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("../data/nn/ae_trained.pth", map_location="cpu"))
model.eval()

In [None]:
import torch
from pyxu.operator.interop import from_torch

def denoiser_prox(arr, tau):
    """Plug-and-play "proximal" step: denoise ``arr`` with the pretrained autoencoder.

    Parameters
    ----------
    arr : torch.Tensor
        Image(s) of shape ``(..., *dim_shape)`` in channel-last layout. Any
        number of leading stacking dimensions is allowed, including none
        (a single image).
    tau : float or None
        Step size supplied by the solver; unused by this denoiser.

    Returns
    -------
    torch.Tensor
        Denoised image(s), same shape as ``arr``.
    """
    batch_shape = arr.shape[: arr.ndim - len(dim_shape)]
    with torch.no_grad():
        # Fold stacking dims into one batch axis and switch to the (N, C, H, W)
        # layout expected by nn.Conv2d.
        torch_input = arr.reshape(-1, *dim_shape).permute(0, 3, 1, 2)
        im_denoised = model(torch_input)
        # Back to channel-last, restoring the original stacking dims. This avoids
        # squeeze(), which would also drop any genuine singleton dimension.
        torch_output = im_denoised.permute(0, 2, 3, 1).reshape(*batch_shape, *dim_shape)
    return torch_output

# Wrap the PyTorch denoiser as a Pyxu ProxFunc. Only `prox` is supplied
# (apply=None), the functional is scalar-valued (codim_shape=1), and the
# dtype is declared float32, presumably to match the torch model's weights.
nn_denoiser = from_torch(
    apply=None,
    prox=denoiser_prox,
    dim_shape=dim_shape,
    codim_shape=1,
    cls=ProxFunc,
    dtype="float32",
    enable_warnings=True,
    name='nn_denoiser',
)

In [None]:
# Apply the learned denoiser once to the pseudo-inverse solution (tau is unused),
# then compare against the undenoised input.
denoised_pinv = nn_denoiser.prox(pinv_solution, None)
compare(denoised_pinv.clip(0,1), pinv_solution.clip(0,1))

In [None]:
import pyxu.runtime as pxrt

# Plug-and-play accelerated PGD with the neural-network denoiser as the
# proximal step. The fit runs in single precision, matching the float32
# declared for `nn_denoiser`.
solver = pxsol.PGD(f=loss, g=nn_denoiser)
with pxrt.Precision(pxrt.Width.SINGLE):
    solver.fit(
        x0=pinv_solution,
        acceleration=True,
        stop_crit=pxst.MaxIter(100),
    )

pytorch_recons = solver.solution()

In [None]:
# PyTorch-denoiser PnP reconstruction vs. the pseudo-inverse solution.
compare(pytorch_recons.clip(0, 1), pinv_solution.clip(0, 1))