In [None]:
#=================================================================================================================
# NEW Effects utilities:
#=================================================================================================================


def add_gaussian_2D(image_data, width, amp):
    lx, ly = image_data.shape
    l_pix, s_pix = np.sort([lx, ly])
    X, Y = np.ogrid[0:lx, 0:ly]
    Z = (norm.pdf(X, loc=lx/2., scale=width*l_pix)*norm.pdf(Y, loc=ly/2., scale=width*l_pix))
    Z /= Z.max()
    image_data += amp*Z
    return image_data


def add_whitehole(image, origin_image, width):
    """
    Replace a centred disk of `image` (in place) with renormalized pixels of `origin_image`.

    The disk contains the pixels whose squared distance from (lx/2, ly/2) is not greater
    than lx*ly/4*width**2, i.e. radius width*sqrt(lx*ly)/2. Those pixels are taken from
    `origin_image` and passed through
    norm_image_percentil(perc_down=0., perc_up=70., perc_norm_up=10.).
    If the disk covers no pixel centre, `image` is returned unchanged.
    """
    lx, ly = image.shape
    X, Y = np.ogrid[0:lx, 0:ly]
    outside = (X - lx / 2) ** 2 + (Y - ly / 2) ** 2 > lx * ly / 4 * (width)**2
    inside = ~outside
    if not inside.any():
        # A very small disk can miss every pixel centre (e.g. odd-sized images);
        # norm_image_percentil would then call .min() on an empty array and raise.
        return image
    image[inside] = norm_image_percentil(origin_image[inside], perc_down=0., perc_up=70., perc_norm_up=10.)
    return image


def add_blackmask(image, width, amp):
    """
    Fill a centred disk of `image` with the constant `amp` (in place).

    The disk contains the pixels whose squared distance from (lx/2, ly/2) is not
    greater than lx*ly/4*width**2, i.e. radius width*sqrt(lx*ly)/2.
    Returns the same array object.
    """
    n_rows, n_cols = image.shape
    rows, cols = np.ogrid[0:n_rows, 0:n_cols]
    radius2 = (rows - n_rows / 2) ** 2 + (cols - n_cols / 2) ** 2
    outside = radius2 > n_rows * n_cols / 4 * (width)**2
    image[~outside] = amp
    return image


def add_fade_ring(image, width1, width2, amp1, amp2):
    """
    Paint a linear radial ramp into a centred annulus of `image` (in place).

    The annulus holds the pixels with r_in < d <= r_out, where d is the distance from
    (lx/2, ly/2), r_in = width1*lx/2 and r_out = width2*lx/2. Inside it each pixel is set to
    amp1 + amp2*(d - r_in)/(r_out - r_in). The ramp therefore runs from amp1 at the inner
    edge to amp1 + amp2 at the outer edge. If the annulus is empty (e.g. width1 >= width2),
    `image` is returned unchanged.
    """
    lx, ly = image.shape
    X, Y = np.ogrid[0:lx, 0:ly]
    distances = np.sqrt((X - lx / 2) ** 2 + (Y - ly / 2) ** 2)
    # NOTE(review): radii scale with lx only, whereas add_blackmask uses sqrt(lx*ly);
    # the two only line up for square images - confirm intended behaviour.
    r_in = width1*lx/2.
    r_out = width2*lx/2.
    ring = (distances > r_in) & (distances <= r_out)
    if not ring.any():
        # Also avoids dividing by zero when r_out == r_in.
        return image
    image[ring] = (distances[ring] - r_in)/(r_out - r_in)*amp2 + amp1
    return image


def add_blackhole(image_data, width, amp):
    """
    Subtract a centred 2D Gaussian dip of depth `amp` from `image_data` (in place),
    then raise every value below 1 up to 1.

    The dip is centred at (lx/2, ly/2). Its standard deviation is width*lx along the
    first axis and width*ly along the second, and it is rescaled so its largest sampled
    value is 1. Returns the same array object.
    """
    n_rows, n_cols = image_data.shape
    rows, cols = np.ogrid[0:n_rows, 0:n_cols]
    dip = norm.pdf(rows, loc=n_rows/2., scale=width*n_rows)
    dip = dip * norm.pdf(cols, loc=n_cols/2., scale=width*n_cols)
    dip /= dip.max()
    image_data -= amp*dip
    below_floor = image_data < 1.
    image_data[below_floor] = 1.
    return image_data


def prepare_border(image_data, width_border):
    """
    Fade `image_data` to zero at all four edges (in place) using exponential ramps.

    Each axis contributes (1 - exp(-d_low/width_border)) * (1 - exp(-d_high/width_border)),
    where d_low and d_high are the pixel distances to the first and last index along that
    axis. The outermost rows and columns on every side therefore become exactly 0.
    width_border is expected to be positive (0 produces NaNs). Returns the same array object.
    """
    lx, ly = image_data.shape
    X, Y = np.ogrid[0:lx, 0:ly]
    # Distances to the high edges are measured from index lx-1 / ly-1 (not lx / ly), so the
    # last row/column is faded exactly like the first one and the border is symmetric.
    mask = ((1 - np.exp(-X/width_border)) * (1 - np.exp(-(lx - 1 - X)/width_border))
            * (1 - np.exp(-Y/width_border)) * (1 - np.exp(-(ly - 1 - Y)/width_border)))
    image_data *= mask
    return image_data


#=================================================================================================================
# Effects utilities:
#=================================================================================================================

#-----------------------------------------------------------------------------------------------------------------
def convolve_image(image_data, size_kernel=5.):
    """
    Returns image after convolving with a 2D Gaussian kernel.

    size_kernel is passed straight to astropy's Gaussian2DKernel. Borders are handled
    with boundary='extend'. Other astropy kernels that were left commented out here as
    alternatives: AiryDisk2DKernel, MexicanHat2DKernel, Tophat2DKernel.
    http://astropy.readthedocs.io/en/latest/convolution/index.html
    """
    gaussian = Gaussian2DKernel(size_kernel)
    return convolve(image_data, gaussian, boundary='extend')


#-----------------------------------------------------------------------------------------------------------------
def apply_aperture_mask(image_data, aperture=1., inverse=False):
    """
    Returns image after applying a hard circular cutoff mask (in place).

    "Outside" pixels are those whose squared distance from (lx/2, ly/2) exceeds
    lx*ly/4*aperture**2, i.e. radius aperture*sqrt(lx*ly)/2. By default the outside
    pixels are set to 0. With a truthy `inverse`, the pixels inside the circle are set
    to 0 instead. Returns the same array object.
    """
    n_rows, n_cols = image_data.shape
    rows, cols = np.ogrid[0:n_rows, 0:n_cols]
    outside = (rows - n_rows / 2) ** 2 + (cols - n_cols / 2) ** 2 > n_rows * n_cols / 4 * (aperture)**2
    image_data[~outside if inverse else outside] = 0.
    return image_data


#-----------------------------------------------------------------------------------------------------------------
def fft_filtering(image_data, filter_aperture, inverse):
    """
    Returns image after applying a circular filter in Fourier space.

    Steps:
    1. Embed the image in the centre of a canvas three times larger on each axis,
       filled with the image mean.
    2. Take the 2D FFT and fftshift it so that low frequencies sit in the centre.
    3. Mask the shifted spectrum with apply_aperture_mask(filter_aperture, inverse).
       inverse=False keeps the central (low) frequencies; inverse=True removes them.
    4. Return the magnitude of the inverse FFT, cropped back to the original size.
    """
    # np.float was removed in NumPy 1.24; np.float64 is the type it aliased.
    image_data = np.array(image_data, dtype=np.float64)
    lx, ly = image_data.shape

    # Mean-padded 3x canvas (presumably to soften FFT wrap-around at the borders).
    big_image = np.full((3*lx, 3*ly), np.mean(image_data))
    big_image[lx:2*lx, ly:2*ly] = image_data

    # Take the Fourier transform and centre the low spatial frequencies.
    spectrum = fftpack.fftshift(fftpack.fft2(big_image))
    # Apply filter to range of frequencies.
    spectrum = apply_aperture_mask(spectrum, filter_aperture, inverse)

    # ifft2 is applied to the still-shifted spectrum. The shift only adds a unit-magnitude
    # linear phase in image space, which np.abs discards.
    filtered = np.abs(fftpack.ifft2(spectrum))
    return filtered[lx:2*lx, ly:2*ly]


#-----------------------------------------------------------------------------------------------------------------
def fancy_mask_effect(image_data):
    """
    Returns image after applying a smooth vignette mask.

    Building the mask Z:
    - Start from a separable Gaussian centred at (lx/2, ly/2) with sigma lx and ly.
    - Rescale it linearly so the corner pixel Z[0, 0] (the smallest value) maps to 0 and
      the reference pixel (int(2.5*sqrt(lx)), int(2.5*sqrt(ly))) maps to 1.
    - Clamp anything above 1.

    The image is then blended towards its minimum: out = (img - min)*Z + min.
    """
    lx, ly = image_data.shape
    X, Y = np.ogrid[0:lx, 0:ly]
    Z = (norm.pdf(X, loc=lx/2., scale=lx)*norm.pdf(Y, loc=ly/2., scale=ly))
    Z = Z/Z.max()

    z_corner = Z[0, 0]
    # np.int was removed in NumPy 1.24. Clamp the reference index so images smaller
    # than ~7 pixels per side no longer raise IndexError.
    ix = min(int(2.5*np.sqrt(lx)), lx - 1)
    iy = min(int(2.5*np.sqrt(ly)), ly - 1)
    span = Z[ix, iy] - z_corner
    if span > 0.:
        # The original divided by Z[ix, iy] alone, which kept Z below 1 everywhere and made
        # the clamp below unreachable. Subtracting the corner offset gives the intended 0->1 ramp.
        Z = (Z - z_corner)/span
        Z[Z > 1.] = 1.
    else:
        # Degenerate (e.g. 1-pixel axis): no usable ramp, leave the image untouched.
        Z = np.ones_like(Z)

    img_min = image_data.min()
    image_data = (image_data - img_min)*Z + img_min

    return image_data

    
#-----------------------------------------------------------------------------------------------------------------
def crop_borders_effect(image_data, crop_amp):
    """
    Returns image with int(crop_amp) pixels trimmed from every edge.

    The result is a slice (view) of `image_data`.
    """
    x_pix, y_pix = image_data.shape
    # np.int (an alias of the builtin int) was removed in NumPy 1.24.
    pixels = int(crop_amp)
    return image_data[pixels : x_pix - pixels, pixels : y_pix - pixels]


#-----------------------------------------------------------------------------------------------------------------
def norm_image_percentil(image_data, perc_down=0., perc_up=100., perc_norm_up=-1):
    """
    Returns image after applying a percentile normalization.

    The image goes through astropy's AsymmetricPercentileInterval(perc_down, perc_up),
    which maps that percentile range to [0, 1] and clips values outside it. The result
    is then stretched to [image_min, image_max], where image_min is the input minimum.

    perc_norm_up is an absolute value, not a percentile. When > 0 it is used as
    image_max; otherwise (default -1) image_max is the input maximum.
    Default: perc_down=0., perc_up=100., perc_norm_up=-1
    """
    image_min = image_data.min()

    interval = AsymmetricPercentileInterval(perc_down, perc_up)

    # Any non-positive value falls back to the data maximum. Previously exactly 0
    # matched neither branch and raised UnboundLocalError.
    if perc_norm_up > 0.:
        image_max = perc_norm_up
    else:
        image_max = image_data.max()

    image_data = (image_max - image_min)*interval(image_data)
    image_data = image_data + image_min

    return image_data


#-----------------------------------------------------------------------------------------------------------------
def norm_image_absolute(image_data, nabs_down=1., nabs_up=np.inf):
    """
    Returns image after applying an absolute value normalization.

    Shifts a copy of the image so its minimum equals nabs_down, then caps every value
    above nabs_up at nabs_up. The input array is not modified.
    Default: nabs_down=1., nabs_up=np.inf
    """
    offset = nabs_down - image_data.min()
    shifted = image_data + offset if False else image_data - image_data.min() + nabs_down
    too_high = shifted > nabs_up
    shifted[too_high] = nabs_up
    return shifted


#----------------------------------------------------------------------------------------------------------------- 
def norm_image_histogram(image_data, norm_down=1., norm_up=20.):
    """
    Returns image after applying a histogram-selection normalization.

    Steps:
    1. Build a histogram with astropy.stats.histogram(bins='scott').
    2. Take as reference level the centre of the bin 3 positions above the most populated
       bin (clamped to the last bin).
    3. Shift the image so that level maps to norm_down, and raise values below norm_down
       to norm_down.

    norm_up is currently unused.
    """
    counts, edges = astrostats.histogram(image_data, bins='scott')
    arg = counts.argmax()
    if arg == 0:
        print('***WARNING: most common pixel is the one with lowest flux')
    # The +3 bin offset is not explained in the source (presumably tuned by eye).
    # Clamp so that edges[arg + 1] stays in range; without this, a mode within the last
    # 3 bins raised IndexError.
    arg = min(arg + 3, len(edges) - 2)
    norm_val = (edges[arg] + edges[arg + 1])/2.

    image_data = image_data - norm_val + norm_down
    image_data[image_data < norm_down] = norm_down

    return image_data


#=================================================================================================================
# Effects:
#=================================================================================================================


def effect_0(image):
    """Identity effect: hands back the very same `image` object, untouched."""
    unchanged = image
    return unchanged


def effect_base(image):
    """
    Base normalization effect.

    Shifts the image minimum to 1 and caps values at 150 (norm_image_absolute). Then
    applies a 30th-100th percentile rescale that keeps the data maximum (norm_image_percentil).
    """
    capped = norm_image_absolute(image, nabs_up=150.)
    return norm_image_percentil(capped, perc_down=30., perc_up=100., perc_norm_up=-1)


def effect_mask(image, amp):
    """Apply fancy_mask_effect to `image`; `amp` is accepted but currently ignored."""
    masked = fancy_mask_effect(image)
    return masked


def effect_fft(image, amp, aperture):
    """
    Fourier low-pass effect, active only when amp > 0.1 (otherwise `image` is returned as-is).

    Steps:
    1. Keep the frequencies inside a circle of relative size 0.1*aperture
       (fft_filtering with inverse=False).
    2. Shift the minimum to 1.
    3. Apply a 50th-100th percentile rescale.
    """
    if not amp > 0.1:
        return image
    filtered = fft_filtering(image, filter_aperture=0.1*aperture, inverse=False)
    filtered = norm_image_absolute(filtered, nabs_down=1.)
    return norm_image_percentil(filtered, perc_down=50., perc_up=100.)


def effect_fft_inv(image, amp, aperture):
    """
    Fourier high-pass effect, active only when amp > 0.1 (otherwise `image` is returned as-is).

    Steps:
    1. Shift the minimum to 1 and cap at 300.
    2. Remove the frequencies inside a circle of relative size 0.1*aperture
       (fft_filtering with inverse=True).
    3. Shift the minimum to 1 again.
    4. Apply a 50th-100th percentile rescale.
    """
    if not amp > 0.1:
        return image
    prepared = norm_image_absolute(image, nabs_up=300.)
    filtered = fft_filtering(prepared, filter_aperture=0.1*aperture, inverse=True)
    filtered = norm_image_absolute(filtered, nabs_down=1.)
    return norm_image_percentil(filtered, perc_down=50., perc_up=100.)


def effect_blackhole(image, width):
    """
    "Black hole" effect, built from a scaled width w = 0.02*width:

    1. Normalize to [1, 10].
    2. Add a bright Gaussian of sigma w*min(lx, ly) and peak 100x the image maximum.
    3. Fill a central disk of relative radius 5.8*w with 1.
    4. Paint a fading ring (values 1 to 11) between relative radii 5.5*w and 5.8*w.
    """
    scaled = width * 0.02
    outer = 5.8*scaled
    img = norm_image_absolute(image, nabs_up=10.)
    img = add_gaussian_2D(img, scaled, 100.*img.max())
    img = add_blackmask(img, outer, 1.)
    return add_fade_ring(img, 5.5*scaled, outer, 1., 10.)


def effect_whitehole(image, width):
    """
    "White hole" effect, built from a scaled width w = 0.02*width:

    1. Normalize to [1, 10] and keep a copy of that normalized image.
    2. Add a bright Gaussian of sigma w*min(lx, ly) and peak 100x the image maximum.
    3. Refill a central disk of relative radius 5.8*w from the saved copy (add_whitehole).
    4. Paint a fading ring (values 1 to 11) between relative radii 5.5*w and 5.8*w.
    """
    scaled = width * 0.02
    outer = 5.8*scaled
    img = norm_image_absolute(image, nabs_up=10.)
    # Snapshot taken before add_gaussian_2D modifies img in place.
    snapshot = img.copy()
    img = add_gaussian_2D(img, scaled, 100.*img.max())
    img = add_whitehole(img, snapshot, outer)
    return add_fade_ring(img, 5.5*scaled, outer, 1., 10.)


def effect_blackwhitehole(image, width, whiteon):
    """
    Dispatch to effect_whitehole (whiteon == 1) or effect_blackhole (whiteon == 0).

    Nothing is done when width <= 0.001 or when whiteon is any other value;
    in those cases `image` is returned as-is.
    """
    if not width > 0.001:
        return image
    if whiteon == 1:
        return effect_whitehole(image, width)
    if whiteon == 0:
        return effect_blackhole(image, width)
    return image


def effect_singularity(image, amp, aperture):
    """
    "Singularity" effect, active only when amp > 0.001 (otherwise `image` is returned as-is).

    Steps:
    1. Normalize to [1, 10], then keep only the top 1.5% of values (percentile rescale to 10).
    2. Add a centred Gaussian of relative width 0.01 and peak `amp`.
    3. High-pass filter it with relative aperture 0.06*aperture.
    4. Renormalize (minimum 1, 20th-100th percentile rescale up to 15).
    5. Crop int(amp) pixels from every edge.
    """
    if amp > 0.001:
        width = 0.01
        image = norm_image_absolute(image, nabs_up=10.)
        image = norm_image_percentil(image, perc_down=98.5, perc_up=100., perc_norm_up=10.)
        image = add_gaussian_2D(image, width, amp)
        image = fft_filtering(image, filter_aperture=0.06*aperture, inverse=True)
        image = norm_image_absolute(image, nabs_down=1.)
        image = norm_image_percentil(image, perc_down=20., perc_up=100., perc_norm_up=15.)
        # amp doubles as the crop size in pixels (truncated to int).
        image = crop_borders_effect(image, amp)
    return image


def effect_singularity_v1(image, amp, aperture):
    """
    Earlier "singularity" variant, active only when amp > 0.001 (otherwise `image` is returned as-is).

    Differences from effect_singularity:
    - The Gaussian peak is fixed at 15.
    - The high-pass aperture is 0.108*amp; the `aperture` argument is ignored.
    - No border crop is applied.
    """
    if not amp > 0.001:
        return image
    out = norm_image_absolute(image, nabs_up=10.)
    out = norm_image_percentil(out, perc_down=98.5, perc_up=100., perc_norm_up=10.)
    out = add_gaussian_2D(out, 0.01, 15.)
    out = fft_filtering(out, filter_aperture=0.108*amp, inverse=True)
    out = norm_image_absolute(out, nabs_down=1.)
    return norm_image_percentil(out, perc_down=20., perc_up=100., perc_norm_up=15.)


def effect_fade_to_black(image, amp):
    """
    Pull the image towards 1 (black): out = (image - 1)*amp + 1.

    Only applied when amp < 0.99; otherwise `image` is returned as-is.
    """
    if not amp < 0.99:
        return image
    shape = np.shape(image)
    gain = np.full(shape, amp)
    return (image - 1.)*gain + np.ones(shape)


#def effect_plato(image, width, amp):
#    if width > 0.001:
#        base_amp = 1000.
#        width *= 0.02
#        image = add_whitehole(image, width, amp*base_amp)
#        image = fft_filtering(image, filter_aperture=0.4, inverse=False)
#    return image


def effect_prepare_border(image, width_border):
    """Normalize to [1, 10] (norm_image_absolute), then fade the edges with prepare_border."""
    return prepare_border(norm_image_absolute(image, nabs_up=10.), width_border)

    