In [None]:
#=================================================================================================================
# Directory utilities:
#=================================================================================================================
        
#-----------------------------------------------------------------------------------------------------------------
def generate_new_images(delete_fits=False, delete_npys=False, delete_pngs=False):
    """
    Make sure the FITS, NPY and PNG sub-folders of ``img_dir`` exist.

    A missing folder is always created. An existing folder is left untouched
    unless its matching ``delete_*`` flag is True, in which case it is removed
    together with all of its contents and recreated empty.
    """

    folders = (
        ('FITS', delete_fits),
        ('NPY', delete_npys),
        ('PNG', delete_pngs),
    )
    for sub_dir, wipe in folders:
        target = '%s/%s' % (img_dir, sub_dir)
        if not os.path.exists(target):
            os.makedirs(target)
            continue
        if wipe:
            shutil.rmtree(target)
            os.makedirs(target)
    

#=================================================================================================================
# Image manipulation utilities:
#=================================================================================================================
       
#-----------------------------------------------------------------------------------------------------------------
def zoom_image_data(image_data, final_size, img_mpc, verbose=False):
    """
    Zoom an image to a common square size of ``final_size`` x ``final_size`` pixels.

    A non-square image is first centre-cropped to a square. If its aspect
    ratio exceeds 1.1 it is rejected instead and returned unmodified. The
    square image is then interpolated to ``final_size + 2*crop_pix`` pixels
    per side, and ``crop_pix`` pixels are trimmed from every border.

    Parameters
    ----------
    image_data : 2-D array
        Input image.
    final_size : int
        Target side length, in pixels.
    img_mpc :
        Not used by this function; kept for interface compatibility.
    verbose : bool
        Print diagnostic messages about the image dimensions.

    Returns
    -------
    image_data_zoom : 2-D array
        The zoomed image or, if the aspect-ratio check failed, the unmodified
        input image.
    problem_ratio : bool
        True when the image was rejected because its aspect ratio exceeds 1.1.
    """

    # Extra border pixels that are interpolated and then trimmed away
    # (presumably to discard interpolation edge effects). Must be > 0, since
    # the final ``[crop_pix:-crop_pix]`` slice would be empty for 0.
    crop_pix = 2

    def crop_image(image_data, img_mpc):
        """Centre-crop ``image_data`` to a square; flag aspect ratios above 1.1."""
        x_pix, y_pix = image_data.shape

        if (1. * x_pix / y_pix > 1.1) or (1. * y_pix / x_pix > 1.1):
            if verbose:
                print('*** WARNING: image ratio > 1.1 ***')
            return image_data, True

        diff = abs(x_pix - y_pix)
        # Drop diff//2 pixels at the end of the long axis and the remaining
        # diff - diff//2 at its start, leaving exactly min(x_pix, y_pix).
        start = diff // 2 + diff % 2
        if x_pix > y_pix:
            image_data = image_data[start : x_pix - diff // 2, :]
        elif y_pix > x_pix:
            image_data = image_data[:, start : y_pix - diff // 2]

        return image_data, False

    x_pix, y_pix = image_data.shape
    problem_ratio = False

    if x_pix != y_pix:
        if verbose:
            print('*** WARNING: original fits dimensions: %i x %i' % (x_pix, y_pix))
        image_data, problem_ratio = crop_image(image_data, img_mpc)
        if problem_ratio:
            return image_data, problem_ratio
        x_pix, y_pix = image_data.shape
        if verbose:
            print('Fixed: final fits dimensions: %i x %i' % (x_pix, y_pix))

    zoom_size = final_size + 2 * crop_pix
    # ``ndimage.zoom`` is the supported public name; the
    # ``ndimage.interpolation`` namespace is deprecated.
    image_data_zoom = ndimage.zoom(image_data, (1. * zoom_size / x_pix, 1. * zoom_size / y_pix))
    image_data_zoom = image_data_zoom[crop_pix:-crop_pix, crop_pix:-crop_pix]

    x_pix, y_pix = image_data_zoom.shape
    if verbose:
        print('Zoomed fits dimensions: %i x %i' % (x_pix, y_pix))

    if (x_pix != final_size) or (y_pix != final_size):
        if verbose:
            print('*** WARNING! Final fits dimensions: %i x %i' % (x_pix, y_pix))

    return image_data_zoom, problem_ratio


#-----------------------------------------------------------------------------------------------------------------
def check_for_nan(image_data, stop=False):
    """
    Check whether ``image_data`` contains any NaN value.

    Parameters
    ----------
    image_data : array_like
        Data to inspect.
    stop : bool
        If True, terminate the program when a NaN is found instead of only
        printing a warning.

    Returns
    -------
    bool
        True if a NaN was found (only reachable when ``stop`` is False),
        False otherwise.

    Raises
    ------
    SystemExit
        If a NaN is found and ``stop`` is True.
    """

    # Scan the array once instead of once per branch.
    has_nan = bool(np.isnan(image_data).any())
    if not has_nan:
        return False

    if stop:
        sys.exit("***ERROR: NaN found")

    print('***WARNING: NaN found')
    return True

    
    
#=================================================================================================================
# Visualization utilities:
#=================================================================================================================

#-----------------------------------------------------------------------------------------------------------------
def plot_image_data(image_data, lognorm=True, test=False):
    """
    Plot ``image_data`` in grey scale with a fixed display range of [1, 50].

    Parameters
    ----------
    image_data : 2-D array
        Image to display.
    lognorm : bool
        Use a logarithmic colour scale (``LogNorm``) instead of a linear one.
    test : bool
        If True, save the figure to ``<vid_dir>/test.png``, close it and
        return None instead of returning it.

    Returns
    -------
    matplotlib.figure.Figure or None
        The 8x8-inch figure, or None when ``test`` is True.
    """

    fig = plt.figure(figsize=(8, 8))
    # A single axes covering the whole figure (no margins).
    ax = fig.add_axes([0., 0., 1., 1.])

    if lognorm:
        # Matplotlib 3.5+ raises ValueError when vmin/vmax are passed together
        # with a norm instance, so the limits go to the norm itself.
        ax.imshow(image_data, cmap='gray', norm=LogNorm(vmin=1., vmax=50.))
    else:
        ax.imshow(image_data, cmap='gray', vmin=1., vmax=50.)

    if test:
        fig.savefig('%s/test.png' % vid_dir)
        fig.clf()
        # Release the figure; otherwise pyplot keeps it alive.
        plt.close(fig)
        return None

    return fig


#=================================================================================================================
# Image selection:
#=================================================================================================================

#-----------------------------------------------------------------------------------------------------------------
def generate_images(sample, config, wise_band, clu_images=-1, produce_png=False):
    """
    Build, or load from cache, WISE images for randomly chosen clusters of ``sample``.

    Clusters are visited in random order. For each one, the cached ``.npy``
    image is loaded if possible. Otherwise the FITS cutout is downloaded,
    zoomed to ``config.img_pix`` pixels, shifted so its minimum equals 1, and
    cached. Images containing NaNs, or flagged as having a bad aspect ratio by
    ``zoom_image_data``, are discarded and do not count towards the requested
    number.

    Parameters
    ----------
    sample :
        Cluster sample indexable by ``'id'`` (presumably a table with a
        column of redMaPPer ids; verify against caller).
    config :
        Configuration object. ``img_pix`` and ``img_mpc`` are read here, and
        it is forwarded to ``Wise_Image.load``/``save``.
    wise_band : int
        WISE band number.
    clu_images : int
        Number of images to produce; -1 means one per cluster in ``sample``.
    produce_png : bool
        Also render a PNG for every accepted image.

    Returns
    -------
    numpy.ndarray
        One row ``[clu_id, wise_band, img_mean, img_std]`` per accepted image,
        in processing (random) order. Fewer rows than requested are returned
        (with a warning) if the sample runs out of usable clusters.
    """

    sample_meta = []

    if clu_images == -1:
        n_images = len(sample['id'])
    else:
        n_images = clu_images

    ## Run code:
    print('\nProcessing cluster nº (redMaPPer id, WISE band):\n')

    # Shuffle a copy so the caller's sample is left untouched.
    cluster_ids = sample['id'].copy()
    np.random.shuffle(cluster_ids)

    c_count = 0
    # Iterate over the ids directly, so the loop ends cleanly (instead of
    # indexing past the end) when discarded images exhaust the sample.
    for clu_id in cluster_ids:
        if c_count >= n_images:
            break

        clu = Cluster(sample, clu_id)
        # Only the angular size of the cutout (arcsec) is needed below.
        _, img_arcsec, _, _ = scale_image('mpc', clu.photo_z, config.img_pix, config.img_mpc)

        print('\r%i (%06d, W%i)' % (c_count+1, clu_id, wise_band), end='\r')

        image = Wise_Image(clu_id, wise_band)
        try:
            image.load(config)
        except (OSError, ValueError, EOFError):
            # No usable cached .npy file (missing or unreadable): rebuild it.
            download_wise_fits(clu_id, clu.ra, clu.dec, img_arcsec, wise_band)
            image_data = get_wise_image_data(clu_id, img_arcsec, wise_band)
            image_data, img_bad_ratio = zoom_image_data(image_data, config.img_pix, config.img_mpc)
            if check_for_nan(image_data, stop=False) or img_bad_ratio:
                # Discarded: does not count towards n_images.
                continue
            norm_down = 1.
            # Shift the image so its minimum pixel value equals ``norm_down``.
            image_data = image_data - image_data.min() + norm_down
            image.save(config, image_data)

        ## Add metadata of image:
        c_count += 1
        sample_meta.append([clu_id, wise_band, image.img_mean, image.img_std])

        if produce_png:
            image.save_png()

    if c_count < n_images:
        print('\n*** WARNING: only %i of %i requested images could be produced' % (c_count, n_images))

    print('\n\nWISE cluster images processing finished (%i images)\n' % c_count)
    print('*'*20)

    return np.array(sample_meta)


#-----------------------------------------------------------------------------------------------------------------
def display_image(config, image_data):
    """
    Crop ``image_data`` vertically to the aspect ratio of the video frame.

    The target row count is ``int(n_cols * config.video_pixy / config.video_pixx)``.
    The surplus rows are split evenly between top and bottom, removing
    ``surplus // 2`` from each side, so an odd surplus leaves one extra row.
    Columns are never cropped, and an image that is already shorter than the
    target is returned uncropped.

    Parameters
    ----------
    config :
        Object providing ``video_pixx`` and ``video_pixy``.
    image_data : 2-D array
        Image indexed as (row, column).

    Returns
    -------
    2-D array
        A view of the vertically cropped image.
    """

    ly, lx = image_data.shape
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is equivalent.
    display_pixy = int(lx * config.video_pixy / config.video_pixx)
    # Rows to drop from each side; clamp at 0 when there is no surplus.
    diffy = max(0, (ly - display_pixy) // 2)
    # Use an explicit stop index: ``image_data[d:-d]`` is empty when d == 0.
    return image_data[diffy:ly - diffy, :]

    
#-----------------------------------------------------------------------------------------------------------------
class Wise_Image(object):
    """
    A WISE image of one cluster in one WISE band, backed by an on-disk cache.

    The array is cached as ``<img_dir>/NPY/IMG_<id>_WISE,_band_<band>.npy``.
    It can also be rendered to the matching ``.png`` under ``<img_dir>/PNG``.
    After ``load`` or ``save``, the instance exposes ``img_data``, plus
    ``img_mean`` and ``img_std`` computed over the display crop returned by
    ``display_image``.
    """

    def __init__(self, clu_id, wise_band):
        self.clu_id = clu_id
        self.wise_band = wise_band
        # Common file stem for both cached formats (the comma is part of the
        # existing on-disk naming scheme and must be kept).
        stem = 'IMG_%06d_WISE,_band_%i' % (clu_id, wise_band)
        self.npy_file = '%s/NPY/%s.npy' % (img_dir, stem)
        self.png_file = '%s/PNG/%s.png' % (img_dir, stem)

    def set_image_data(self, config, image_data):
        """Store ``image_data`` and the statistics of its display crop."""
        cropped = display_image(config, image_data)
        self.img_mean = np.mean(cropped)
        self.img_std = np.std(cropped)
        self.img_data = image_data

    def load(self, config):
        """Read the cached ``.npy`` array and store it (see ``set_image_data``)."""
        self.set_image_data(config, np.load(self.npy_file))

    def save(self, config, image_data):
        """Write ``image_data`` to the ``.npy`` cache, then store it."""
        np.save(self.npy_file, image_data)
        self.set_image_data(config, image_data)

    def save_png(self):
        """Render ``img_data`` to the PNG path, unless that file already exists."""
        if os.path.isfile(self.png_file):
            return
        fig = plot_image_data(self.img_data)
        fig.savefig(self.png_file)
        fig.clf()
        plt.close()
    
    