# Line Search for the Wiener Filter Parameter

In this notebook, we are going to determine the $\lambda$ for the Wiener filter using a line search. The value we are looking for is the argmin of:
$$
\|Y-H\ast X_\lambda\|^2_2 + \lambda \|X_\lambda\|^2_2
$$
where $X_\lambda$ is the output of the Tikhonov (Wiener) filter $\mathcal{T}_\lambda(Y,H)$.

> For the sake of simplicity, the line search uses the identity regularization rather than a Laplacian one. In the scope of this work, the precise value of $\lambda$ is not crucial: any value of $\lambda$ that ensures a stable deconvolution is satisfactory. This choice is motivated by the approach in the [ForWaRD](https://ieeexplore.ieee.org/abstract/document/1261329?casa_token=9c8WAbl7hiAAAAAA:dOrXfVmlVGs6RxCdpnGULKOutryQ3n1dRgRc_Yug2Y_oNo4nGzwGKHgivpagDZfMZXigFYS5i2I3) method.

## Load requirements

In [1]:
%matplotlib inline
import sys

# Add library path to PYTHONPATH
lib_path = '/gpfswork/rech/xdy/uze68md/GitHub/'
path_alphatransform = lib_path+'alpha-transform'
path_score = lib_path+'score'
sys.path.insert(0, path_alphatransform)
sys.path.insert(0, path_score)

# Libraries
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy import fft
from scipy.ndimage import zoom
import cadmos_lib as cl
import tensorflow as tf
import galsim
from galsim import Image
import galsim.hsm
import galflow as gf
from galaxy2galaxy import problems

# Functions

def recons(batch, interp_factor=2):
    """Rebuild observation-like images from a batch of Tikhonov-filtered images.

    Channel 0 of every image in ``batch['inputs_tikho']`` is resampled with
    ``scipy.ndimage.zoom``, taken to Fourier space with a real 2D FFT,
    multiplied by channel 0 of the matching ``batch['psf_cfht']`` entry,
    divided by channel 0 of the matching ``batch['psf_hst']`` entry, and
    brought back with the inverse real 2D FFT.

    Parameters
    ----------
    batch : dict
        Must provide the keys 'inputs_tikho', 'psf_cfht' and 'psf_hst'.
        The PSF entries are used as Fourier-space multipliers, so they must
        broadcast against the ``rfft2`` of the resampled images.
    interp_factor : scalar
        Zoom factor handed to ``scipy.ndimage.zoom``.

    Returns
    -------
    tuple
        ``(reconstructions, spectra)`` where ``spectra`` holds the real FFT of
        the resampled Tikhonov images and ``reconstructions`` the images
        obtained after the PSF multiplication/division and inverse FFT.
    """
    # Resample channel 0 of each filtered image
    upsampled = np.array([zoom(img[..., 0], zoom=interp_factor)
                          for img in batch['inputs_tikho']])
    # Real 2D FFT, one image at a time
    spectra = np.array(list(map(np.fft.rfft2, upsampled)))

    # Swap PSFs in Fourier space: multiply by the input PSF, divide by the target PSF
    swapped = np.array([
        spectrum * psf_in[..., 0] / psf_target[..., 0]
        for spectrum, psf_in, psf_target in zip(spectra,
                                                batch['psf_cfht'],
                                                batch['psf_hst'])
    ])
    # Back to direct space
    rebuilt = np.array([np.fft.irfft2(s) for s in swapped])
    return rebuilt, spectra

def ir2tf(imp_resp, shape):
    """Compute the transfer function of an impulse response.

    The impulse response is copied into the leading corner of a zero array
    of size ``shape``, circularly shifted so that its central sample sits at
    index 0 along each of its axes, and transformed with a real FFT over the
    last two axes.

    Parameters
    ----------
    imp_resp : ndarray
        Impulse response.
    shape : tuple of int
        Shape of the zero-padded array that is transformed.

    Returns
    -------
    ndarray
        Output of ``scipy.fft.rfftn`` over the last two axes of the padded,
        shifted impulse response.
    """
    n_fft_axes = 2

    # Zero padding: the impulse response fills the leading corner
    padded = np.zeros(shape)
    corner = tuple(slice(0, size) for size in imp_resp.shape)
    padded[corner] = imp_resp

    # Bring the centre of the impulse response to the FFT origin so the
    # transfer function carries no spurious phase. Floor division covers
    # both odd and even sizes.
    roll_axes = tuple(range(len(imp_resp.shape)))
    roll_shifts = tuple(-(size // 2) for size in imp_resp.shape)
    padded = np.roll(padded, shift=roll_shifts, axis=roll_axes)

    return fft.rfftn(padded, axes=range(-n_fft_axes, 0))

def laplacian(shape):
    """Return the 2D discrete Laplacian as a transfer function and as a kernel.

    Parameters
    ----------
    shape : tuple of int
        Shape passed to ``ir2tf`` for zero-padding the kernel.

    Returns
    -------
    tuple
        ``(transfer_function, kernel)`` where ``kernel`` is the 3x3
        five-point stencil (4 at the centre, -1 on the four direct
        neighbours, 0 on the corners) and ``transfer_function`` is
        ``ir2tf(kernel, shape)``.
    """
    # Five-point Laplacian stencil written out explicitly
    kernel = np.array([[0.0, -1.0, 0.0],
                       [-1.0, 4.0, -1.0],
                       [0.0, -1.0, 0.0]])
    return ir2tf(kernel, shape), kernel

def laplacian_tf(shape):
    """Return the Laplacian transfer function for ``shape`` as a TensorFlow tensor.

    Only the first element returned by ``laplacian(shape)`` (the transfer
    function) is converted; the kernel itself is discarded.
    """
    transfer_function, _ = laplacian(shape)
    return tf.convert_to_tensor(transfer_function)

def wiener_tf(image, psf, balance, laplacian=True):
    r"""Apply a Wiener (Tikhonov) filter to an image.

    The filter is built in Fourier space as

    .. math:: W = \frac{\bar{H}}{|H|^2 + \lambda\,|R|^2}

    where :math:`H` is the transfer function, :math:`\lambda` is ``balance``
    and :math:`R` is the Laplacian transfer function (``laplacian=True``) or
    the identity (``laplacian=False``). The filter is then multiplied with
    the real 2D FFT of ``image``.

    Parameters
    ----------
    image   : 2D TensorFlow tensor
        Image in the direct space. It is cast to float32 before the FFT.
    psf     : 2D TensorFlow tensor
        PSF in the Fourier space (or K space). With ``laplacian=True`` only:
        if its shape differs from the shape of the Laplacian transfer
        function, it is instead treated as a direct-space PSF and converted
        with ``ifftshift`` followed by ``rfft2d``.
    balance : scalar
        Weight applied to the regularization.
    laplacian : boolean
        If true the Laplacian regularization is used, else the identity
        regularization is used.

    Returns
    -------
    tuple
        The first element is the filtered image in the Fourier space.
        The second element is the PSF in the Fourier space (also known as
        the transfer function) that was actually used by the filter.
    """
    # Build the Laplacian transfer function once; it serves both the shape
    # check on the PSF and the regularization term below.
    # (`laplacian` here is the boolean flag, not the module-level function.)
    reg = laplacian_tf(image.shape) if laplacian else None

    trans_func = psf
    if reg is not None and psf.shape != reg.shape:
        # Shape mismatch: psf is taken to be in direct space
        trans_func = tf.signal.rfft2d(tf.signal.ifftshift(tf.cast(psf, 'float32')))

    conj_trans_func = tf.cast(tf.math.conj(trans_func), 'complex64')
    power_spectrum = tf.dtypes.cast(tf.math.abs(trans_func), 'complex64') ** 2

    # Identity regularization reduces to the bare balance weight
    penalty = balance
    if reg is not None:
        penalty = balance * tf.dtypes.cast(tf.math.abs(reg), 'complex64') ** 2

    wiener_filter = conj_trans_func / (power_spectrum + penalty)

    # Apply the Wiener filter in Fourier (or K) space
    wiener_applied = wiener_filter * tf.signal.rfft2d(tf.cast(image, 'float32'))

    return wiener_applied, trans_func

def pre_proc_unet(dico):
    r"""Add noise, upsample and Tikhonov-filter one dataset example.

    This is ``make_preproc`` with the regularization weight fixed to 9e-3.
    The processing steps (noise with standard deviation 23.59, bicubic
    resize of the 'inputs' image to 128x128 with flux rescaling, Wiener
    filtering with the 'psf_cfht' entry, K-space convolution with the
    'psf_hst' entry) are implemented there, so the two functions cannot
    drift apart.

    Parameters
    ----------
    dico : dict
        Dataset example. The keys 'inputs', 'psf_cfht' and 'psf_hst' are
        read.

    Returns
    -------
    dict
        The same dictionary, with 'inputs' replaced by its noisy version and
        the keys 'inputs_cfht' (resized noisy image) and 'inputs_tikho'
        (Tikhonov-filtered image convolved with the target PSF) added.

    Example
    -------
    >>> from galaxy2galaxy import problems # to list available problems run problems.available()
    >>> problem128 = problems.problem('attrs2img_cosmos_hst2euclide')
    >>> dset = problem128.dataset(Modes.TRAIN, data_dir='attrs2img_cosmos_hst2euclide')
    >>> dset = dset.map(pre_proc_unet)
    """
    balance = 9e-3  # determined using line search
    return make_preproc(balance)(dico)

def make_preproc(balance):
    """Build a dataset preprocessing function for a given regularization weight.

    Parameters
    ----------
    balance : scalar
        Regularization weight forwarded to ``wiener_tf``.

    Returns
    -------
    callable
        A function taking and returning a dataset example (dict), suitable
        for ``Dataset.map``.
    """
    def pre_proc_unet(dico):
        r"""Add noise, upsample and Tikhonov-filter one dataset example.

        Parameters
        ----------
        dico : dict
            Dataset example. The keys 'inputs', 'psf_cfht' and 'psf_hst' are
            read.

        Returns
        -------
        dict
            The same dictionary, with 'inputs' replaced by its noisy version
            and the keys 'inputs_cfht' (resized noisy image) and
            'inputs_tikho' (Tikhonov-filtered image convolved with the
            target PSF) added.

        Example
        -------
        >>> from galaxy2galaxy import problems # to list available problems run problems.available()
        >>> problem128 = problems.problem('attrs2img_cosmos_hst2euclide')
        >>> dset = problem128.dataset(Modes.TRAIN, data_dir='attrs2img_cosmos_hst2euclide')
        >>> dset = dset.map(make_preproc(1e-2))
        """
        # Step 1: Gaussian noise at the CFHT level.
        # For the estimation of the CFHT noise standard deviation see section 3 of:
        # https://github.com/CosmoStat/ShapeDeconv/blob/master/data/CFHT/HST2CFHT.ipynb
        sigma_cfht = 23.59
        dico['inputs'] = dico['inputs'] + tf.random_normal(
            shape=tf.shape(dico['inputs']),
            mean=0.0,
            stddev=sigma_cfht,
            dtype=tf.float32)

        # Step 2: bicubic interpolation onto a grid interp_factor times finer
        interp_factor = 2
        Nx, Ny = 64, 64
        upsampled = tf.image.resize(dico['inputs'],
                                    [Nx*interp_factor, Ny*interp_factor],
                                    method=tf.image.ResizeMethod.BICUBIC)
        # Every original pixel now spans interp_factor**2 pixels, so the
        # flux is rescaled accordingly
        dico['inputs_cfht'] = upsampled / interp_factor**2

        # Step 3: Tikhonov (Wiener) filtering with the input PSF; the result
        # is in K space and gets a leading axis of size 1 for gf.kconvolve
        filtered, _ = wiener_tf(dico['inputs_cfht'][...,0], dico['psf_cfht'][...,0], balance)
        filtered = tf.expand_dims(filtered, axis=0)

        # Step 4: move the last axis of the target PSF to the front.
        # NOTE(review): this is a reshape, not a transpose -- it only matches
        # a transpose when the last axis has size 1; confirm against the data.
        hst_shape = dico['psf_hst'].shape
        psf_hst = tf.cast(tf.reshape(dico['psf_hst'], [hst_shape[-1], *hst_shape[:2]]),
                          'complex64')

        # gf.kconvolve performs a convolution in the K (Fourier) space:
        # inputs are given in K space, the output is in the direct space
        convolved = gf.kconvolve(filtered, psf_hst,zero_padding_factor=1,interp_factor=interp_factor)
        # Drop the leading axis added above
        dico['inputs_tikho'] = convolved[0,...]

        return dico
    return pre_proc_unet

## Prepare Dataset

In [2]:
# Alias the estimator mode keys (Modes.EVAL is passed to problem128.dataset
# in the cells below) and instantiate the galaxy2galaxy problem
# 'attrs2img_cosmos_cfht2hst'
Modes = tf.estimator.ModeKeys
problem128 = problems.problem('attrs2img_cosmos_cfht2hst')

### Dataset Precheck

In [3]:
# Sanity check: draw one preprocessed evaluation batch with balance = 10**-2
n_batch = 128
dset = (
    problem128
    .dataset(Modes.EVAL, data_dir='/gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/')
    .repeat()
    .map(make_preproc(10**(-2)))
    .batch(n_batch)
)
# One-shot iterator: every sess.run(iterator) yields the next batch
iterator = dset.make_one_shot_iterator().get_next()
sess = tf.Session()
# Fetch the first batch
batch = sess.run(iterator)


INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.


Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.


### Magnitude Line Search

We start by performing a grid search to determine the magnitude (exponent) of $\lambda$.

In [4]:
# Candidate exponents: lambda = 10**mag
mags = [-3,-2,-1,0,1,2,3] # magnitudes
# mag_opt and loss_min are not read in this cell; they are retained in case
# other cells reference them
mag_opt = mags[0]
loss_min = -1
all_losses = []

interp_factor = 2.0
n_batch = 128

# resize observations
# NOTE(review): `batch` here is the one drawn before this cell runs, so `obs`
# carries a different noise realization than the batches drawn in the loop
# (noise is sampled inside make_preproc) -- confirm this is intended.
obs = np.array([zoom(i[...,0], zoom=interp_factor) for i in batch['inputs']])

# make Laplacian operator (not used in the loss below, which uses the
# identity regularization)
lap_filter,_ = laplacian(obs.shape[-2:])

# The raw dataset does not depend on lambda: build it once
dset = problem128.dataset(Modes.EVAL, data_dir='/gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/')
dset = dset.repeat()

for mag in mags:
    balance = 10**mag

    dset_tmp = dset.map(make_preproc(balance))
    dset_tmp = dset_tmp.batch(n_batch)
    # Build an iterator over the dataset
    iterator = dset_tmp.make_one_shot_iterator().get_next()
    # The context manager closes the session, so sessions do not pile up
    # over the iterations
    with tf.Session() as sess:
        batch = sess.run(iterator)

    tikh_recons, tikho = recons(batch, interp_factor=interp_factor)

    # Data-fidelity term: mean squared residual per image
    error = obs - tikh_recons
    mse_list = np.array([np.mean(e**2) for e in error])
    # Add lambda times the mean squared value of the filtered image
    loss_list = np.array([m + balance * np.mean(np.fft.irfft2(t)**2) for m,t in zip(mse_list, tikho)])
    loss = np.mean(loss_list)

    # concatenate losses
    all_losses += [loss]

mag = mags[np.argmin(all_losses)]
print("The optimal magnitude is {}".format(mag))

INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2

### Significand Line Search

Now we perform a line search to determine the significand (the first digit) of $\lambda$. 

In [5]:
# Candidate lambdas around the optimal magnitude `mag` found previously:
# 5..9 times 10**(mag-1), then 1..9 times 10**mag
values1 = np.array([5, 6 ,7 ,8 , 9])
values2 = np.array([1,2,3,4,5,6,7,8,9])
values = np.hstack([values1 * 10**(mag-1), values2 * 10**mag])

# loss_min is not read in this cell; retained in case other cells reference it
loss_min = -1
all_losses = []

interp_factor = 2.0
n_batch = 128

# The raw dataset does not depend on lambda: build it once
dset = problem128.dataset(Modes.EVAL, data_dir='/gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/')
dset = dset.repeat()

for v in values:

    dset_tmp = dset.map(make_preproc(v))
    dset_tmp = dset_tmp.batch(n_batch)
    # Build an iterator over the dataset
    iterator = dset_tmp.make_one_shot_iterator().get_next()
    # The context manager closes the session, so sessions do not pile up
    # over the iterations
    with tf.Session() as sess:
        batch = sess.run(iterator)

    tikh_recons,tikho = recons(batch, interp_factor=interp_factor)

    # Data-fidelity term: mean squared residual per image.
    # `obs` is the array of resized observations computed in the magnitude
    # search cell.
    error = obs - tikh_recons
    mse_list = np.array([np.mean(e**2) for e in error])
    # The regularization term must be weighted by the candidate lambda `v`
    # itself (not by the fixed 10**mag), otherwise every candidate is scored
    # with the wrong penalty.
    loss_list = np.array([m + v * np.mean(np.fft.irfft2(t)**2) for m,t in zip(mse_list, tikho)])
    loss = np.mean(loss_list)

    # concatenate losses
    all_losses += [loss]

value = values[np.argmin(all_losses)]
print("The optimal value is {}".format(value))

INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2hst-dev*
INFO:tensorflow:partition: 0 num_data_files: 2
INFO:tensorflow:Reading data files from /gpfswork/rech/xdy/uze68md/data/attrs2img_cosmos_cfht2hst/attrs2img_cosmos_cfht2