Pablo Jimeno - UPV/EHU 2016

In [None]:
##################################################################################################################
### Misc utilities:
##################################################################################################################

#-------------------------------------------------------------------  
def scale_image(scale, *args):
    """Derive the pixel scale and the angular/physical size of an image.

    Parameters
    ----------
    scale : {'mpc', 'arcsec'}
        Which size is supplied in ``args``.
    *args
        If ``scale == 'mpc'``:    ``(img_z, img_pix, img_mpc)``.
        If ``scale == 'arcsec'``: ``(img_z, img_pix, img_arcsec)``.
        ``img_z`` is the redshift and ``img_pix`` the number of pixels spanned
        by the given size.

    Returns
    -------
    sdss_scale, img_arcsec, img_arcmin, img_mpc
        Pixel scale (arcsec per pixel) and the image size in arcsec, arcmin
        and Mpc. Sizes are converted with the Planck13 angular diameter
        distance.

    Raises
    ------
    ValueError
        If ``scale`` is not 'mpc' or 'arcsec'.
    """
    # Validate up front; previously an unknown scale crashed further down
    # with an UnboundLocalError on img_z.
    if scale not in ('mpc', 'arcsec'):
        raise ValueError("scale must be 'mpc' or 'arcsec', got %r" % (scale,))

    img_z, img_pix, img_size = args

    # Report the bad redshift before it is handed to the cosmology calculator.
    if img_z == -1.:
        print('*** ERROR: wrong redshift equal to -1.')

    arcsecs_in_rad = 360. * 3600. / (2. * np.pi)
    ad_dist = Planck13.angular_diameter_distance(img_z).value

    if scale == 'mpc':
        img_mpc = img_size
        img_rads = img_mpc / ad_dist
        img_arcsec = img_rads * arcsecs_in_rad
    else:
        img_arcsec = img_size
        img_rads = img_arcsec / arcsecs_in_rad
        img_mpc = img_rads * ad_dist

    img_arcmin = img_arcsec / 60.
    # 1.* forces true division: with integer arcsec and pixel counts,
    # Python 2 would otherwise floor the pixel scale (e.g. 100/256 == 0).
    sdss_scale = 1. * img_arcsec / img_pix

    return sdss_scale, img_arcsec, img_arcmin, img_mpc


#-------------------------------------------------------------------  
def distances_image(N_pix):
    """Return an (N_pix, N_pix) float array of pixel distances to the image centre.

    The centre is the point (N_pix/2., N_pix/2.). For even N_pix this is the
    pixel with index N_pix/2; for odd N_pix it falls between pixels.
    """
    offsets = np.arange(N_pix, dtype=float) - N_pix / 2.
    row_offsets = offsets[:, np.newaxis]
    col_offsets = offsets[np.newaxis, :]
    return np.sqrt(row_offsets ** 2 + col_offsets ** 2)


##################################################################################################################
### Image manipulation utilities:
##################################################################################################################

#-------------------------------------------------------------------  
def crop_image(image_data, img_mpc):
    """Crop the longer axis of a 2-D image so that the image becomes square.

    The excess pixels are removed symmetrically. For an odd excess, the
    extra pixel is taken from the low-index side.

    Parameters
    ----------
    image_data : 2-D ndarray
        Image to crop. Returned unchanged if it is already square.
    img_mpc : float
        Physical size covered by the image's long side.

    Returns
    -------
    image_data : 2-D ndarray
        Square view of the input, with side equal to the shorter axis.
    img_mpc_size : float
        Physical size of the cropped side: ``img_mpc * short / long``.
    """
    x_pix, y_pix = image_data.shape
    # sorted() is ascending, so the shorter side comes first. The original
    # unpacked np.sort as (long, short), which inverted the size ratio.
    short_pix, long_pix = sorted((x_pix, y_pix))

    # Float division: int/int floors under Python 2 (e.g. 150/100 == 1),
    # which both hid this warning and broke the size ratio below.
    if float(long_pix) / short_pix > 1.1:
        print('***WARNING: image ratio > 1.1')

    img_mpc_size = img_mpc * float(short_pix) / long_pix

    # Floor division keeps slice indices integral under Python 3 as well.
    diff = long_pix - short_pix
    start = diff // 2 + diff % 2

    if x_pix > y_pix:
        image_data = image_data[start:x_pix - diff // 2, :]
    elif y_pix > x_pix:
        image_data = image_data[:, start:y_pix - diff // 2]

    return image_data, img_mpc_size

        
#-------------------------------------------------------------------  
def zoom_image_data(image_data, final_size, img_mpc):
    """Resample a 2-D image onto a ``final_size`` x ``final_size`` grid.

    Non-square inputs are first cropped to a square with ``crop_image``,
    which also rescales the physical size. The square image is then
    interpolated with ``scipy.ndimage.zoom`` at its default spline order.
    No flux renormalisation is applied, so the pixel sum is not preserved.

    Parameters
    ----------
    image_data : 2-D ndarray
        Input image.
    final_size : int
        Target number of pixels per side.
    img_mpc : float
        Physical size of the input image (passed to ``crop_image`` when
        cropping is needed).

    Returns
    -------
    image_data_zoom : 2-D ndarray
        Resampled image.
    img_mpc_size : float
        Physical size of the (possibly cropped) image.
    """
    x_pix, y_pix = image_data.shape

    if x_pix != y_pix:
        print('***WARNING: original fits dimensions: %i x %i' % (x_pix, y_pix))
        image_data, img_mpc_size = crop_image(image_data, img_mpc)
        x_pix, y_pix = image_data.shape
        print('Fixed: final fits dimensions: %i x %i' % (x_pix, y_pix))
    else:
        img_mpc_size = img_mpc

    zoom_factors = (1. * final_size / x_pix, 1. * final_size / y_pix)
    # ndimage.zoom is the public entry point; ndimage.interpolation.* is deprecated.
    image_data_zoom = ndimage.zoom(image_data, zoom_factors)

    # zoom() rounds the output shape, so confirm that the target size was reached.
    if image_data_zoom.shape != (final_size, final_size):
        print('***WARNING! Final fits dimensions: %i x %i' % image_data_zoom.shape)

    return image_data_zoom, img_mpc_size


#-------------------------------------------------------------------  
def check_for_nan(data, stop=False):
    """Report whether ``data`` contains any NaN.

    Parameters
    ----------
    data : array_like
        Values to inspect.
    stop : bool, optional
        If True, terminate via ``sys.exit`` when a NaN is found instead of
        printing a warning.

    Returns
    -------
    bool
        True if a NaN was found (returned only when ``stop`` is False),
        otherwise False. Previously the no-NaN case returned None.
    """
    # Evaluate the NaN scan once instead of once per branch.
    has_nan = bool(np.isnan(data).any())
    if not has_nan:
        return False
    if stop:
        sys.exit("***ERROR: NaN found")
    print('***WARNING: NaN found')
    return True

    
#-------------------------------------------------------------------  
def convolve_image(image_data, size_kernel=5.):
    """Smooth an image by convolving it with an astropy ``Gaussian2DKernel``.

    ``size_kernel`` is passed as the kernel's first argument (its standard
    deviation in pixels, per astropy's API). Borders are handled with
    ``boundary='extend'``.

    Reference: http://astropy.readthedocs.io/en/latest/convolution/index.html
    """
    gaussian = Gaussian2DKernel(size_kernel)
    return convolve(image_data, gaussian, boundary='extend')


#-------------------------------------------------------------------  
def apply_aperture_mask(image_data, aperture=1., inverse=False):
    """Zero the pixels outside (or inside) a centred circular aperture.

    The aperture is centred on pixel (lx // 2, ly // 2). This is where
    ``fftshift`` places the zero-frequency term, for both even and odd
    sizes. The squared radius is (lx/2) * (ly/2) * aperture**2; for a square
    image that is a radius of ``aperture`` times half the side.

    Parameters
    ----------
    image_data : 2-D ndarray (real or complex)
        Array to mask. It is modified IN PLACE.
    aperture : float, optional
        Radius as a fraction of the half-size.
    inverse : bool, optional
        False: zero everything outside the aperture.
        True: zero everything inside it.

    Returns
    -------
    ndarray
        The same array object that was passed in.
    """
    lx, ly = image_data.shape
    X, Y = np.ogrid[0:lx, 0:ly]

    # Make the division semantics explicit: bare '/' floored under Python 2
    # but gave a half-pixel (asymmetric) centre for odd sizes under Python 3.
    centre_x, centre_y = lx // 2, ly // 2
    radius_sq = (lx * ly / 4.) * aperture ** 2

    outside = (X - centre_x) ** 2 + (Y - centre_y) ** 2 > radius_sq

    if inverse:
        image_data[~outside] = 0.
    else:
        image_data[outside] = 0.

    return image_data


#-------------------------------------------------------------------  
def fft_filtering(image_data, filter_aperture=0.5, inverse=False):
    """Filter an image in Fourier space with a centred circular aperture.

    The image's ``fft2`` is shifted so that low spatial frequencies sit at
    the centre. ``apply_aperture_mask`` then zeroes the frequencies outside
    the aperture (``inverse=False``, low-pass) or inside it
    (``inverse=True``, high-pass). The modulus of the inverse transform is
    returned.

    The spectrum is not un-shifted before ``ifft2``. A circular shift in
    frequency only multiplies the spatial result by a unit-modulus phase,
    and ``np.abs`` removes that phase.
    """
    centred_spectrum = fftpack.fftshift(fftpack.fft2(image_data))
    masked_spectrum = apply_aperture_mask(centred_spectrum, filter_aperture, inverse)
    return np.abs(fftpack.ifft2(masked_spectrum))


##################################################################################################################
### Filtering utilities:
##################################################################################################################

#-------------------------------------------------------------------  
def create_antenna(img_pix, fwhm, display=False):
    """Creates FFT antenna.

    Builds a circular Gaussian beam centred on (img_pix/2, img_pix/2), its
    2-D FFT, and the ring-averaged power of that FFT.

    Parameters
    ----------
    img_pix : int
        Side of the square image, in pixels.
    fwhm : float
        Full width at half maximum of the Gaussian, in pixels.
    display : bool, optional
        If True, print a summary and plot the antenna, |FFT| and the radial
        power profile.

    Returns
    -------
    antenna : ndarray, shape (img_pix, img_pix)
        ``exp(-d**2 / (2 sigma**2))``, where d is the distance to the centre.
        It is not normalised to unit sum.
    fft_antenna : complex ndarray, shape (img_pix, img_pix)
        Unshifted ``fftpack.fft2`` of the antenna (zero frequency at [0, 0]),
        divided by the modulus of that zero-frequency term.
    power_antenna : ndarray, shape (nf,)
        Mean of |fft_antenna|**2 over integer-radius rings around the zero
        frequency, for ring radii 0 .. nf-1 (nf ~ sqrt(2) * img_pix / 2).
    """
    distances = distances_image(img_pix)

    ## Antenna in real space
    # For a Gaussian, FWHM = 2 * sqrt(2 ln 2) * sigma.
    sigma = (fwhm / 2.0) / np.sqrt(2.0 * np.log(2.0))
    antenna = np.exp(-distances ** 2 / (2.0 * sigma ** 2))

    ## FFT antenna (kept unshifted; normalised by its zero-frequency term)
    fft_antenna = fftpack.fft2(antenna)
    fft_antenna = fft_antenna / np.abs(fft_antenna[0, 0])

    ## Power of antenna
    # The rings are centred where distances_image is centred. The spectrum
    # must therefore be fftshift-ed so the zero frequency sits there (as
    # create_MHW does); otherwise the profile is measured around the Nyquist
    # corner of the unshifted FFT.
    power_2d = np.abs(fftpack.fftshift(fft_antenna)) ** 2.
    ring_index = distances.astype(int)  # np.int was removed in NumPy 1.24
    nf = int(np.sqrt(2.0 * ((img_pix / 2.0) ** 2)))
    power_antenna = np.array(
        [np.mean(power_2d[ring_index == i]) for i in range(nf)], dtype=float)

    if display:

        print('\nAntenna:')
        print('fwhm      : %i pixels' % fwhm)
        print('Npix_side : %i pixels' % img_pix)

        fig = plt.figure(figsize=(10, 4))

        ax1 = fig.add_subplot(131)
        ax1.imshow(antenna)
        ax1.set_title('antenna')

        ax2 = fig.add_subplot(132)
        ax2.imshow(np.abs(fft_antenna))
        ax2.set_title('FFT antenna')

        ax3 = fig.add_subplot(133)
        ax3.plot(range(nf), power_antenna)
        ax3.set_yscale('log')
        ax3.set_title('antenna power')

        fig.tight_layout()
        plt.show()
        plt.close(fig)

    return antenna, fft_antenna, power_antenna


#-------------------------------------------------------------------  
def create_MHW(img_pix, MHW_scale, display=False):
    """Creates Mexican Hat Wavelet (MHW).

    Parameters
    ----------
    img_pix : int
        Side of the square image, in pixels.
    MHW_scale : float
        Wavelet scale, in pixels. Real-space distances are divided by it.
    display : bool, optional
        If True, print a summary and plot the wavelet, |FFT| and the radial
        power profile.

    Returns
    -------
    MHW : ndarray, shape (img_pix, img_pix)
        ``(2 - r**2) * exp(-r**2 / 2)`` with r = distance / MHW_scale,
        divided by its sum.
    fft_MHW : complex ndarray, shape (img_pix, img_pix)
        ``fftpack.fft2`` of the wavelet, fftshift-ed so that the zero
        frequency is at the centre.
    power_MHW : ndarray, shape (nf,)
        Mean of |fft_MHW|**2 over integer-radius (in Fourier pixels) rings
        around the centre, for radii 0 .. nf-1 (nf ~ sqrt(2) * img_pix / 2).
    """
    pix_distances = distances_image(img_pix)
    distances = pix_distances / MHW_scale

    ## MHW in real space
    MHW = (2.0 - distances ** 2) * np.exp(-distances ** 2 / 2.0)
    # NOTE(review): the continuous 2-D Mexican hat integrates to zero, so this
    # sum is only a discretisation/truncation residual. It may be small or
    # change sign with MHW_scale; confirm the intended normalisation.
    normalization = np.sum(MHW)
    MHW = MHW / normalization

    ## FFT MHW (shifted: zero frequency at the centre)
    fft_MHW = fftpack.fftshift(fftpack.fft2(MHW))

    ## Power of MHW
    # Fourier rings are indexed in pixels. Binning on the *scaled* distances
    # (as before) left every ring beyond nf / MHW_scale empty (NaN).
    power_2d = np.abs(fft_MHW) ** 2.
    ring_index = pix_distances.astype(int)  # np.int was removed in NumPy 1.24
    nf = int(np.sqrt(2.0 * ((img_pix / 2.0) ** 2)))
    power_MHW = np.array(
        [np.mean(power_2d[ring_index == i]) for i in range(nf)], dtype=float)

    if display:

        print('\nMHW:')
        print('scale     : %i pixels' % MHW_scale)
        print('Npix_side : %i pixels' % img_pix)

        fig = plt.figure(figsize=(10, 4))

        ax1 = fig.add_subplot(131)
        ax1.imshow(MHW)
        ax1.set_title('MHW')

        ax2 = fig.add_subplot(132)
        ax2.imshow(np.abs(fft_MHW))
        ax2.set_title('FFT MHW')

        ax3 = fig.add_subplot(133)
        ax3.plot(range(nf), power_MHW)
        ax3.set_yscale('log')
        ax3.set_title('MHW power')

        fig.tight_layout()
        plt.show()
        plt.close(fig)

    return MHW, fft_MHW, power_MHW


##################################################################################################################
### Image normalization utilities:
##################################################################################################################

#-------------------------------------------------------------------  
def norm_image_absolute(image_data, norm_down=1., norm_up=np.inf):
    """Shift an image so its minimum equals ``norm_down``, then cap it at ``norm_up``.

    A new array is returned; ``image_data`` itself is not modified.
    """
    floor_value = image_data.min()
    rescaled = image_data - floor_value
    rescaled = rescaled + norm_down

    saturated = rescaled > norm_up
    rescaled[saturated] = norm_up

    return rescaled

    
##################################################################################################################
### Visualization utilities:
##################################################################################################################

#-------------------------------------------------------------------  
def plot_image_data(image_data, lognorm=True):
    """Render a 2-D image as a borderless grey-scale figure.

    The figure is sized so that one image pixel maps to one figure pixel at
    ``my_dpi`` dots per inch.

    Parameters
    ----------
    image_data : 2-D ndarray
        Image to show, with shape (rows, columns).
    lognorm : bool, optional
        If True, map colours through a logarithmic ``LogNorm``.

    Returns
    -------
    matplotlib Figure
    """
    my_dpi = 100

    n_rows, n_cols = image_data.shape
    # figsize is (width, height): width follows the number of columns.
    # The previous (rows, cols) order swapped the aspect of non-square images.
    fig = plt.figure(figsize=(1. * n_cols / my_dpi, 1. * n_rows / my_dpi), dpi=my_dpi)

    fig.add_axes([0., 0., 1., 1.])
    plt.axis('off')

    norm = LogNorm() if lognorm else None
    plt.imshow(image_data, cmap='gray', norm=norm)

    return fig


#-------------------------------------------------------------------  
def plot_spectrum_data(image_data):
    """Show log10 of the 2-D power spectrum |FFT|**2 of an image.

    The spectrum is fftshift-ed so that the zero frequency is at the centre.
    It is drawn on an 8x8-inch figure whose axes fill the whole canvas.
    Zero-power pixels give -inf under log10.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_axes([0., 0., 1., 1.])

    centred = fftpack.fftshift(fftpack.fft2(image_data))
    power_2d = np.abs(centred) ** 2
    ax.imshow(np.log10(power_2d))

    return fig


#-------------------------------------------------------------------  
def plot_histogram(image_data, lognorm=False, norm_min=1., norm_max=20.):
    """Plot a density-normalised histogram of pixel values.

    Bins are chosen with the 'freedman' rule of the ``hist`` helper that is
    imported elsewhere in this notebook (presumably astropy/astroML).

    Parameters
    ----------
    image_data : array_like
        Pixel values; any shape (flattened before binning).
    lognorm : bool, optional
        If True, histogram ``log(image_data)`` instead of the raw values.
    norm_min, norm_max : float, optional
        The x-axis is limited to [norm_min - 0.5, norm_max + 0.5].

    Returns
    -------
    matplotlib Figure
    """
    fig = plt.figure(figsize=(5, 4))
    ax = fig.add_subplot(111)

    bin_type = 'freedman'

    # Flatten: the bin-width rules behind bins='freedman' operate on 1-D
    # samples. This is a no-op for inputs that are already 1-D.
    values = np.ravel(image_data)
    if lognorm:
        values = np.log(values)

    # 'density' replaces 'normed', which matplotlib's hist removed in 3.1.
    hist(values, bins=bin_type, histtype='stepfilled', alpha=0.2, density=True)
    ax.set_xlabel('DI')
    ax.set_ylabel('P(DI)')
    # NOTE(review): with lognorm=True the histogram is in log units but these
    # limits are still linear -- confirm the intended range.
    plt.xlim([norm_min - 0.5, norm_max + 0.5])
    plt.title('Light histogram')
    plt.tight_layout()

    return fig