<center><img src="../docs/assets/animated_logo.png" align="center" height="100" width="100"></center>

# [`Optimus-Primal`](https://github.com/astro-informatics/Optimus-Primal) - __Custom Operator__ Interactive Tutorial
---

In this interactive tutorial we demonstrate basic usage of `optimusprimal` for a 2-dimensional noisy inpainting problem with a custom measurement operator.


Here we run a basic 2D unconstrained proximal primal-dual solver, this time with a custom measurement operator $\Phi$. 
We consider the canonical problem $y = \Phi x + n$, where $n \sim \mathcal{N}(0, \sigma^2)$ is Gaussian noise and $\Phi$ encodes the 
forward model of the problem. This inverse problem can be solved via the unconstrained optimisation 

$$
\min_x [ ||(\Phi x-y)/\sigma||^2_2 + \lambda ||\Psi^{\dagger} x||_1 ]
$$

where $x \in \mathbb{R}^{N \times N}$ is the 2D signal we wish to recover (here a known ground truth), 
$y \in \mathbb{R}^{M}$ are the $M \leq N^2$ simulated noisy observations, $\sigma$ is the noise standard deviation, 
$\Psi$ is a sparsifying dictionary, and $\lambda$ is the regularisation parameter, which acts as 
a Lagrange multiplier balancing data fidelity against prior information. Before we begin, we 
need to import `optimusprimal` and some example-specific packages.


In [None]:
import matplotlib.pyplot as plt
import numpy as np
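try:
    import scipy.misc as misc
    misc.ascent                            # ascent() was removed from scipy.misc in SciPy 1.12
except (ImportError, AttributeError):
    import scipy.datasets as misc          # Replacement module; ascent() needs the optional pooch package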
from scipy.signal import resample

import optimusprimal.primal_dual as primal_dual
import optimusprimal.grad_operators as grad_operators
import optimusprimal.linear_operators as linear_operators
import optimusprimal.prox_operators as prox_operators
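
To make the objective introduced above concrete, the following minimal sketch evaluates it with plain NumPy on a toy problem. The helper `objective`, the identity stand-ins for $\Phi$ and $\Psi^{\dagger}$, and the toy values are illustrative assumptions, not part of `optimusprimal`.

```python
import numpy as np

def objective(x, y, phi, psi_adj, sigma, lam):
    """Evaluate ||(phi(x) - y) / sigma||_2^2 + lam * ||psi_adj(x)||_1."""
    residual = (phi(x) - y) / sigma
    data_fidelity = np.sum(np.abs(residual) ** 2)
    prior = lam * np.sum(np.abs(psi_adj(x)))
    return data_fidelity + prior

# Hypothetical toy setting: identity measurement operator and identity dictionary.
identity = lambda v: v
x_true = np.array([[1.0, 0.0], [0.0, 2.0]])
y_toy = x_true.copy()                       # noiseless observations

# At the truth only the l1 prior contributes; at zero only the data fidelity does.
value_at_truth = objective(x_true, y_toy, identity, identity, sigma=0.1, lam=0.5)
value_at_zero = objective(np.zeros_like(x_true), y_toy, identity, identity, sigma=0.1, lam=0.5)
print(value_at_truth, value_at_zero)
```

A larger $\lambda$ raises the cost of non-sparse solutions, while a smaller $\sigma$ makes deviations from the data more expensive.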

First, we need to define some options for the solver. These include:

- `tol`: convergence tolerance for the iterations
- `iter`: maximum number of iterations
- `update_iter`: number of iterations between logged diagnostics
- `record_iters`: whether to record the full diagnostic information



In [None]:
options = {"tol": 1e-5, "iter": 5000, "update_iter": 50, "record_iters": False}

At this point, let's set up the forward model of the problem, which we 
encode in the measurement operator $\Phi$.
For this problem, let's consider a forward model that masks pixels in 
image space, in which case the inverse problem becomes a noisy 
in-painting problem!



In [None]:
class custom_phi(linear_operators.LinearOperator):
    """A custom linear operator e.g. a custom measurement operator"""

    def __init__(self, dim, masking):
        """Initialise the operator with any necessary parameters"""
        self.dim = dim

        # Generate a random mask
        mask = np.full(dim**2, False)
        mask[:int(masking*dim**2)] = True
        np.random.shuffle(mask)
        self.mask = mask
    
    def dir_op(self, x):
        """Forward linear operator"""
        return x.flatten('C')[self.mask]
    
    def adj_op(self, x):
        """Adjoint linear operator"""
        f = np.zeros(self.dim**2)
        f[self.mask] = x
        return f.reshape(self.dim, self.dim)
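
When writing a custom operator it is worth checking that `adj_op` really is the adjoint of `dir_op`. A common way is the dot-product test, $\langle \Phi x, v \rangle = \langle x, \Phi^{\dagger} v \rangle$ for random $x$ and $v$. The sketch below re-implements the masking operator as standalone NumPy functions (an assumption made so the check runs without `optimusprimal`); the names `x_test` and `v_test` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 8

# Random mask keeping half of the pixels, mirroring the operator above.
mask = np.zeros(dim**2, dtype=bool)
mask[: dim**2 // 2] = True
rng.shuffle(mask)

def dir_op(x):
    """Forward operator: flatten the image and keep only the unmasked pixels."""
    return x.flatten("C")[mask]

def adj_op(v):
    """Adjoint operator: place the measurements back on a zero-filled image."""
    f = np.zeros(dim**2)
    f[mask] = v
    return f.reshape(dim, dim)

x_test = rng.standard_normal((dim, dim))
v_test = rng.standard_normal(int(mask.sum()))

lhs = np.vdot(dir_op(x_test), v_test)    # <Phi x, v>
rhs = np.vdot(x_test, adj_op(v_test))    # <x, Phi^dagger v>
adjoint_ok = np.isclose(lhs, rhs)
print(adjoint_ok)
```

If the two inner products disagree, the forward and adjoint methods are inconsistent and gradient-based solvers may converge poorly or to the wrong solution.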

Next, we simulate a noisy in-painting setting by contaminating simulated 
observations $y$ of a known signal $x$ with Gaussian noise $n$.



In [None]:
ISNR = 30.0                                            # Input signal to noise ratio
sigma = 10 ** (-ISNR / 20.0)                           # Noise standard deviation
reg_param = 4.5                                        # Regularisation parameter
res = 256                                              # Resolution we want to work with
masking = 0.5                                          # Fraction of the observations to use
phi = custom_phi(dim=res, masking=masking)             # Custom forward-model of the problem

x = misc.ascent()                                      # Scipy's ascent benchmark image
for i in range(2):
    x = resample(x, axis=i, num=res)
x /= np.nanmax(x)                                      # Normalise image

y = phi.dir_op(x)                                      # Simulated observations y
n = np.random.normal(0, sigma, y.shape)                # Random Gaussian noise
y += n                                                 # Contaminate y with noise
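
The noise level above follows from $\sigma = 10^{-\mathrm{ISNR}/20}$. One way to read this, assuming the signal amplitude is of order one after normalisation, is that the root-mean-square noise sits ISNR decibels below unit amplitude. The sketch below checks this empirically; `sigma_toy` and `noise_level_db` are hypothetical names used only for illustration.

```python
import numpy as np

isnr_db = 30.0
sigma_toy = 10 ** (-isnr_db / 20.0)       # roughly 0.0316

# Draw many noise samples and measure their RMS level in decibels.
rng = np.random.default_rng(1)
noise = rng.normal(0.0, sigma_toy, size=100_000)
noise_level_db = 20 * np.log10(np.sqrt(np.mean(noise**2)))
print(noise_level_db)                     # close to -isnr_db
```

Because the image is normalised by its maximum rather than its RMS, the SNR actually measured on the data can differ somewhat from the nominal value.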

For the unconstrained problem with Gaussian noise, the data-fidelity term
is the squared $\ell_2$-norm of the noise-weighted residual, which enters the solver
through its gradient. Here we set up the corresponding gradient operator.



In [None]:
g = grad_operators.l2_norm(sigma, y, phi)
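
As a rough illustration of what such a gradient operator computes, the sketch below takes the data-fidelity term $f(x) = \|\Phi x - y\|_2^2 / (2\sigma^2)$ with gradient $\Phi^{\dagger}(\Phi x - y)/\sigma^2$ and verifies it by finite differences for a 1D masking operator. The factor of $1/2$ is an assumed convention; the scaling used internally by `optimusprimal` may differ.

```python
import numpy as np

rng = np.random.default_rng(2)
n_pix, sigma_toy = 16, 0.1

# Hypothetical 1D masking operator and observations.
mask = rng.random(n_pix) < 0.5
mask[0] = True                            # guarantee at least one measurement
y_toy = rng.standard_normal(int(mask.sum()))

def fidelity(x):
    """f(x) = ||Phi x - y||_2^2 / (2 sigma^2) with Phi a masking operator."""
    return np.sum((x[mask] - y_toy) ** 2) / (2 * sigma_toy**2)

def fidelity_grad(x):
    """Analytic gradient Phi^dagger (Phi x - y) / sigma^2."""
    grad = np.zeros_like(x)
    grad[mask] = (x[mask] - y_toy) / sigma_toy**2
    return grad

# Central finite differences along each coordinate direction.
x0 = rng.standard_normal(n_pix)
eps = 1e-6
numeric = np.array([
    (fidelity(x0 + eps * e) - fidelity(x0 - eps * e)) / (2 * eps)
    for e in np.eye(n_pix)
])
max_error = np.max(np.abs(numeric - fidelity_grad(x0)))
print(max_error)
```

Note that the gradient vanishes on masked pixels, which is why a prior is needed to fill them in.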

We regularise this inverse problem by adopting a wavelet sparsity $\ell_1$-norm prior.
To do this we first define what wavelets we wish to use, in this case a
combination of Daubechies family wavelets, and which levels to consider.
Any combination of wavelet families available in the [`PyWavelets`](https://tinyurl.com/5n7wzpmb) 
package may be selected.



In [None]:
wav = ["db1", "db4"]                                     # Wavelet dictionaries to combine
levels = 4                                               # Wavelet levels to consider [1-6]
psi = linear_operators.dictionary(wav, levels, x.shape)  # Wavelet linear operator

Next we construct the $\ell_1$-norm proximal operator, passing the wavelets
($\Psi$) as the dictionary in which the $\ell_1$-norm is computed. We also add a
reality constraint `f` for good measure, as we know a priori that our
signal $x$ is real.



In [None]:
h = prox_operators.l1_norm(np.max(np.abs(psi.dir_op(phi.adj_op(y)))) * reg_param, psi)
f = prox_operators.real_prox()
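
For intuition, the proximal operator of $\tau \|\cdot\|_1$ is element-wise soft thresholding, which shrinks every coefficient towards zero by $\tau$ and zeroes those with magnitude below $\tau$. The function `soft_threshold` below is a standalone sketch of that standard result, not the `optimusprimal` implementation, and how the library combines the threshold with its step sizes is not shown here.

```python
import numpy as np

def soft_threshold(coeffs, threshold):
    """Proximal operator of threshold * ||.||_1 (element-wise shrinkage)."""
    return np.sign(coeffs) * np.maximum(np.abs(coeffs) - threshold, 0.0)

coeffs = np.array([-3.0, -0.5, 0.0, 0.2, 1.5])
shrunk = soft_threshold(coeffs, 1.0)
# Entries with magnitude at most 1 become zero; the rest move 1 closer to zero.
print(shrunk)
```

In the cell above the threshold is the regularisation parameter scaled by the largest wavelet coefficient of the back-projected data, so `reg_param` acts relative to the data rather than as an absolute value.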

Finally we run the optimisation...



In [None]:
# Note that phi.adj_op(y) is a crude ("dirty") first estimate. In practice one may wish to start the optimisation from a better initial guess!
best_estimate, diagnostics = primal_dual.FBPD(phi.adj_op(y), options, g, f, h)
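
To show how a gradient step and a proximal step interact, here is a sketch of plain forward-backward splitting (ISTA) on a small sparse-recovery problem. This is deliberately a simpler algorithm than the primal-dual solver called above, and it is not how `optimusprimal` is implemented. The matrix `A`, the sizes, and the parameter values are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(3)
n_meas, n_unknown = 40, 100

# Hypothetical problem: random measurements of a 5-sparse vector plus noise.
A = rng.standard_normal((n_meas, n_unknown)) / np.sqrt(n_meas)
x_sparse = np.zeros(n_unknown)
x_sparse[rng.choice(n_unknown, size=5, replace=False)] = rng.standard_normal(5)
y_toy = A @ x_sparse + 0.01 * rng.standard_normal(n_meas)

lam = 0.02
step = 1.0 / np.linalg.norm(A, 2) ** 2    # 1 / Lipschitz constant of the gradient

def cost(x):
    """0.5 * ||A x - y||_2^2 + lam * ||x||_1."""
    return 0.5 * np.sum((A @ x - y_toy) ** 2) + lam * np.sum(np.abs(x))

x_est = np.zeros(n_unknown)
history = [cost(x_est)]
for _ in range(200):
    grad = A.T @ (A @ x_est - y_toy)                               # forward (gradient) step
    z = x_est - step * grad
    x_est = np.sign(z) * np.maximum(np.abs(z) - step * lam, 0.0)   # backward (prox) step
    history.append(cost(x_est))

print(history[0], history[-1])
```

With this step size the cost never increases from one iteration to the next, which makes the recorded `history` a convenient convergence diagnostic.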

...and plot the results!



In [None]:
def eval_snr(x, x_est):
    """Return the SNR (in dB) of x_est relative to x; 0 is returned for identical inputs."""
    if np.array_equal(x, x_est):
        return 0
    num = np.sqrt(np.sum(np.abs(x) ** 2))
    den = np.sqrt(np.sum(np.abs(x - x_est) ** 2))
    return round(20*np.log10(num/den), 2)

fig, axs = plt.subplots(1, 3, figsize=[10, 5])

titles = ["Data", "Truth", "Reconstruction"]
est = [phi.adj_op(y), x, best_estimate]

for i in range(3):
    axs[i].imshow(est[i], cmap="magma", vmax=np.max(x), vmin=np.min(x))
    axs[i].set_title(titles[i], fontsize=14)
    axs[i].set_xlabel("SNR: {} dB".format(eval_snr(x, est[i])), fontsize=12)

    plt.setp(axs[i].get_xticklabels(), visible=False)
    plt.setp(axs[i].get_yticklabels(), visible=False)

plt.show()
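
As a quick sanity check on the SNR values shown under each panel, recall that the metric is $20 \log_{10}(\|x\|_2 / \|x - \hat{x}\|_2)$, so an error ten times larger lowers the SNR by 20 dB. The sketch below uses a hypothetical helper `snr_db` (without the rounding and identical-input handling of `eval_snr`) on a constant toy image.

```python
import numpy as np

def snr_db(x, x_est):
    """Signal-to-noise ratio in dB of the estimate x_est relative to x."""
    return 20 * np.log10(np.linalg.norm(x) / np.linalg.norm(x - x_est))

x_ref = np.ones((4, 4))
snr_small_error = snr_db(x_ref, x_ref + 0.01)   # error is 1% of the signal
snr_large_error = snr_db(x_ref, x_ref + 0.1)    # error is 10% of the signal
print(snr_small_error, snr_large_error)
```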