# In [1]:
#=================================================================================================================
# Image selection:
#=================================================================================================================

def locate_closest_frame(time, cfg=None):
    """
    Return the index of the frame closest to ``time``.

    Parameters:
    time: time value, in the same units as ``t_per_frame``.
    cfg: optional configuration providing ``t_per_frame``; defaults to the
        module-level ``config``.

    Returns:
    int frame index (np.round semantics: halves round to even).
    """
    cfg = config if cfg is None else cfg
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    return int(np.round(time/cfg.t_per_frame))


def locate_closest_image(time, cfg=None):
    """
    Return the index of the image closest to ``time``.

    Parameters:
    time: time value; ``cfg.t_offset_start`` is subtracted before dividing
        by ``cfg.t_per_image``.
    cfg: optional configuration providing ``t_offset_start`` and
        ``t_per_image``; defaults to the module-level ``config``.

    Returns:
    int image index (np.round semantics: halves round to even).
    """
    cfg = config if cfg is None else cfg
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    return int(np.round((time - cfg.t_offset_start)/cfg.t_per_image))


#-----------------------------------------------------------------------------------------------------------------
def locate_positions(config, sound):
    """
    Return the image positions that fall on the beats of ``sound``.

    The first position is the image closest to ``sound.start``; the next ones
    are spaced by the number of images spanned by one ``sound.rate`` interval.
    At most ``sound.nbeats`` positions are produced, stopping at the first one
    that is not below ``config.n_images``.

    NOTE(review): locate_closest_image reads the module-level ``config``, not
    the ``config`` passed here; confirm they are the same object.

    Returns:
    Integer numpy array of positions (possibly empty).
    """
    pos_start = locate_closest_image(sound.start)
    jump = locate_closest_image(sound.start + sound.rate) - pos_start

    positions = []
    for i in range(sound.nbeats):
        position = pos_start + i*jump
        if position >= config.n_images:
            break
        positions.append(position)

    # dtype=int so that an empty result is still a valid index array
    # (np.array([]) would be float and cannot be used for indexing).
    return np.array(positions, dtype=int)



#-----------------------------------------------------------------------------------------------------------------
def order_images(sample_meta, config, sounds):
    """
    Build the sequence of image ids shown in the video.

    Images are sorted by column 3 of their meta data. The highest-ranked
    images are placed on the beat positions of ``sounds[1]``. The next ones go
    on the beat positions of ``sounds[0]`` that do not coincide with
    ``sounds[1]``. The remaining positions are filled, in order, with the
    lowest-ranked images.

    NOTE(review): the layout comment below puts the std dev in column 3, but
    fix_start writes [clu_id, rich, photo_z, mean, std] (mean in column 3);
    confirm which meta layout reaches this function.

    Parameters:
    sample_meta: per-image meta rows (2-D array-like); column 0 is the image
        (cluster) id and column 3 the sorting key.
    config: configuration; ``config.n_images`` is the sequence length.
    sounds: at least two sound objects accepted by locate_positions.

    Returns:
    Float array of length ``config.n_images`` with the id at each position.
    """
    print('\nCreating sequence...\n')

    meta = np.asarray(sample_meta)
    sequence_ids = np.zeros(config.n_images)
    # Explicit mask of assigned slots (instead of using 0. as "empty", which
    # would clash with a genuine id 0).
    filled = np.zeros(config.n_images, dtype=bool)

    # meta = [clu_id, wise_band, np.mean(image_data), np.std(image_data)]
    ordered_ids = meta[:, 0][np.argsort(meta[:, 3])]
    n_ids = len(ordered_ids)

    positions0 = np.asarray(locate_positions(config, sounds[0]), dtype=int)
    positions1 = np.asarray(locate_positions(config, sounds[1]), dtype=int)
    # Beats shared by both sounds are assigned to sounds[1] only.
    positions0 = positions0[np.isin(positions0, positions1, invert=True)]

    n0, n1 = len(positions0), len(positions1)
    # Explicit bounds: a negative-slice form breaks when n1 == 0 (-0 == 0).
    split1 = max(n_ids - n1, 0)
    split0 = max(n_ids - n1 - n0, 0)
    sound1_ids = ordered_ids[split1:]
    sound0_ids = ordered_ids[split0:split1]

    sequence_ids[positions0] = sound0_ids
    sequence_ids[positions1] = sound1_ids
    filled[positions0] = True
    filled[positions1] = True

    base_positions = np.argwhere(~filled).flatten()
    sequence_ids[base_positions] = ordered_ids[0:len(base_positions)]

    print('Sequence created\n')
    print('*'*20)

    return sequence_ids


#-----------------------------------------------------------------------------------------------------------------
def find_eff_treshold(sample_images, dist_images):
    """
    Returns the effect treshold needed to obtain a separation of
    dist_images between images that trigger effects in the sequence
    of images considered.

    The threshold on the per-image pixel std dev is raised in steps of 0.1
    until fewer than ``0.8*len(sample_images)/dist_images`` images exceed it.

    Parameters:
    sample_images: sequence of array-like images.
    dist_images: desired separation (in images) between effect images; > 0.

    Returns:
    The threshold (float).

    Raises:
    ValueError: if sample_images is empty or dist_images is not positive
        (the threshold search could never terminate in those cases).
    """
    n_images = len(sample_images)
    if n_images == 0:
        raise ValueError('sample_images is empty')
    if dist_images <= 0:
        raise ValueError('dist_images must be positive, got %r' % (dist_images,))

    print('\nComputing sample images statistics...\n')

    # One pass per image; builtin float because np.float was removed in NumPy 1.24.
    means = np.empty(n_images)
    stddevs = np.empty(n_images)
    for k, x in enumerate(sample_images):
        data = np.asarray(x, dtype=float)
        means[k] = np.mean(data)
        stddevs[k] = np.std(data)

    print('number of images found: %i' % n_images)
    print('mean light in images: %.5f' % np.mean(means))
    print('mean std dev in images: %.5f' % np.mean(stddevs))
    print('std dev of std devs in images: %.5f' % np.std(stddevs))

    n_images_treshold = (0.8*n_images)/dist_images

    # Terminates: n_images_treshold > 0 and the count reaches 0 once the
    # threshold passes the largest std dev.
    eff_treshold = 0.
    while True:
        eff_treshold += 0.1
        if (stddevs > eff_treshold).sum() < n_images_treshold:
            break

    print('\nEffect treshold set at std dev = %.2f\n' % eff_treshold)

    print('*'*20)

    return eff_treshold


#-----------------------------------------------------------------------------------------------------------------  
def fix_start(sample, sample_images, sample_meta, img_pix, img_mpc, wise_band, fixed_ids):
    """
    Overwrite the first images of the set with fixed cluster images, so
    that several realizations start the same way.

    For the i-th id in ``fixed_ids`` the WISE image of that cluster is
    downloaded and processed. The image is then written to
    ``sample_images[i]`` and its meta data
    ``[clu_id, rich, photo_z, mean, std]`` to ``sample_meta[i]``. Both
    containers are modified in place. An image containing NaNs, or flagged
    through ``bad_ratio`` by zoom_image_data, is discarded and entry i is
    left untouched.

    Returns:
    1) The sequence of images (same object as sample_images).
    2) The meta data (same object as sample_meta).
    """
    if len(fixed_ids) == 0:
        return sample_images, sample_meta

    print('\nAdding fixed images...\n')

    for i, clu_id in enumerate(fixed_ids):

        clu = Cluster(sample, clu_id)
        # NOTE(review): img_mpc is rebound to the value returned by
        # scale_image and reused on the next iteration; confirm intended.
        _, img_arcsec, _, img_mpc = scale_image('mpc', clu.photo_z, img_pix, img_mpc)

        # Progress stays on one line (the Python 2 trailing-comma idiom
        # does not suppress the newline in Python 3).
        print('%i (%06d)...' % (i, clu_id), end=' ')

        download_wise_fits(clu_id, clu.ra, clu.dec, img_arcsec, wise_band)
        image_data = get_wise_image_data(clu_id, img_arcsec, wise_band)
        image_data, img_mpc_size, bad_ratio = zoom_image_data(image_data, img_pix, img_mpc)
        image_data = norm_image_absolute(image_data, norm_down=1.)

        ## Add image to set of images:
        if check_for_nan(image_data, stop=False) or bad_ratio:
            print('[Discarded]', end=' ')
            continue

        sample_images[i] = image_data
        sample_meta[i] = [clu_id, clu.rich, clu.photo_z, np.mean(image_data), np.std(image_data)]

    print('\n\nWISE cluster images processing finished\n')
    print('*'*20)

    return sample_images, sample_meta


#-----------------------------------------------------------------------------------------------------------------  
def add_extra_loop_images(sample_images, sample_meta):
    """
    Close the image sequence into a loop.

    Images: the entries at indices 0..3 of the (growing) sequence are
    appended at the end, so the first four images are repeated. If there
    are fewer than four images, earlier entries are repeated cyclically.
    The result is checked with check_for_nan(stop=True).

    Meta data: rotated left by two entries (``meta[2:] + meta[:2]``). It is
    NOT extended, so its length stays the same as the input.
    NOTE(review): the meta rotation vs. image appending looks deliberate
    but is undocumented; confirm against the frame/meta consumers.

    Returns:
    1) Image array with the four loop images appended.
    2) The rotated meta data.
    """
    looped = sample_images
    for idx in range(4):
        # Index into the growing sequence so short inputs wrap around.
        extra = np.array([looped[idx]])
        looped = np.concatenate((looped, extra), axis=0)

    ## Final result:
    check_for_nan(looped, stop=True)
    head = sample_meta[0:2]
    tail = sample_meta[2:]
    rotated_meta = np.concatenate((tail, head))

    return looped, rotated_meta



#=================================================================================================================
# Interpolation:
#=================================================================================================================

#-----------------------------------------------------------------------------------------------------------------
def create_interp_frames(sample_ids, config, wise_band, renew=False):
    """
    Routine to create the base frames needed to produce the video.

    Loads the images listed in ``sample_ids`` and builds a 3rd order (cubic)
    interpolation in time between them to obtain the extra frames. Each frame
    is saved as a float16 ``.npy`` file in ``<img_dir>/interp_frames``.

    The interpolation is done in overlapping windows of about
    ``im_per_step`` images. Consecutive windows write consecutive,
    non-overlapping ranges of frames.

    Parameters:
    sample_ids: ids of the images, in sequence order (loaded via Wise_Image).
    config: configuration; reads fpi, n_images, n_frames,
        frames_offset_start, frames_black_start and log_interp.
    wise_band: band passed to Wise_Image.
    renew: if True, previous frames are deleted and all frames recomputed;
        otherwise frames already on disk are skipped.
    """
    frames_dir = '%s/interp_frames' % img_dir

    ## Delete previous frames when renewing; make sure the directory exists:
    if renew and os.path.exists(frames_dir):
        shutil.rmtree(frames_dir)
    if not os.path.exists(frames_dir):
        os.makedirs(frames_dir)

    print('\nLoading images...')

    ## Load images:
    sample_images = []
    for clu_id in sample_ids:
        image = Wise_Image(clu_id, wise_band)
        image.load(config)
        sample_images.append(image.img_data)

    ## Check images (they must be square):
    for im in sample_images:
        x_pix, y_pix = im.shape
        if x_pix != y_pix:
            sys.exit('***ERROR: input image dimensions: %i x %i' % (x_pix, y_pix))

    ## Define interpolator:
    def interpolator(t_image_vals, interp_sample):
        # Cubic interpolator along the image axis; optionally on log-values
        # (frames are exponentiated back when evaluated).
        if config.log_interp:
            interp_sample = np.log(interp_sample)
        return interp1d(t_image_vals, interp_sample, axis=0, kind=3)

    print('\nCreating frames...\n')

    ## Compute interpolation:

    fpi = config.fpi
    n_images = config.n_images
    n_frames = config.n_frames

    t_start1 = config.frames_offset_start
    t_start2 = config.frames_black_start

    # Times built by multiplication rather than a float np.arange, which can
    # return one element too many when fpi is not an integer.
    t_image_vals = t_start1 + fpi*np.arange(n_images)
    t_frame_vals = np.arange(t_start2, n_frames, 1.)

    im_per_step = 10
    min_points = 4  # a cubic spline needs at least 4 points
    n_loaded = len(sample_images)
    inter_steps = int(np.ceil(n_loaded/im_per_step))

    for i in range(inter_steps):

        first = (i == 0)
        last = (i == inter_steps - 1)

        # Image window: this step's images plus overlap with the neighbours.
        if first:
            lo = 0
        else:
            lo = i*im_per_step - 2
            if last:
                # Widen a short final window so the cubic spline has enough points.
                lo = max(0, min(lo, n_loaded - min_points))
        hi = n_loaded if last else (i + 1)*im_per_step + 1

        t_image_step = t_image_vals[lo:hi]
        images_step = sample_images[lo:hi]

        # Frame range written by this window (continues where the previous ended).
        f_lo = t_image_vals[0] if first else t_image_vals[i*im_per_step - 1]
        f_hi = t_image_step[-1] + 1 if last else t_image_step[-2]
        t_frame_step = np.arange(f_lo, f_hi, 1.)

        interp_f = interpolator(t_image_step, images_step)

        ## Produce interpolated frames:
        for t in t_frame_step:

            comp = 100.*t/len(t_frame_vals)
            print('\rFrame %04d - %.2f completed' % (t, comp), end='')
            frame_file = '%s/interp_frames/frame_%04d.npy' % (img_dir, t)

            if os.path.exists(frame_file) and not renew:
                continue

            img_t = interp_f(t)
            if config.log_interp:
                img_t = np.exp(img_t)
            np.save(frame_file, np.array(img_t, dtype=np.float16))

    print('\r 100.00 completed\n\nSequence frames created\n')
    print('*'*20)

    
    

#-----------------------------------------------------------------------------------------------------------------
def define_eff_sequence(fsigma_base, fsigma_eff, display=False):
    """
    Reads the frames produced and defines the effect sequence.

    fsigma_base: gives the duration of the "base" sequence.
    fsigma_eff: gives the duration of the "effects" sequence.
    The "fsigma"s are given as the std dev (in terms of the fraction of the
    length of the sequence) of a gaussian centered in the middle of the
    sequence.

    The frames are read from ``<img_dir>/frames/frame_NNNNN.npy``, numbered
    0..n-1, where n is the number of ``.npy`` files in that directory.

    Returns:
    1) The "base" sequence: 1.00001 minus the base gaussian.
    2) The weighted "effects" sequence: effects sequence times the effects
       gaussian.
    3) The "effects" sequence: per-frame std dev minus its minimum.
    """
    frames_dir = '%s/frames' % img_dir
    # Count only frame files, so stray files do not create missing frame indices.
    n_frames = len([name for name in os.listdir(frames_dir)
                    if name.endswith('.npy') and os.path.isfile(os.path.join(frames_dir, name))])

    def get_image_stddev(f):
        f_file = '%s/frames/frame_%05d.npy' % (img_dir, f)
        # Builtin float: np.float was removed in NumPy 1.24.
        return np.std(np.asarray(np.load(f_file), dtype=float))

    def gaussian_distribution(f, fsigma):
        # Unit-amplitude gaussian centered in the middle of the sequence.
        mu = n_frames/2.
        sigma = n_frames*fsigma
        return np.exp(-(f - mu)**2/(2.*sigma**2))

    f_vector = np.arange(n_frames)
    image_stddevs = np.array([get_image_stddev(f) for f in f_vector])

    eff_sequence = image_stddevs - image_stddevs.min()
    w_eff_sequence = eff_sequence*gaussian_distribution(f_vector, fsigma_eff)
    base_sequence = 1.00001 - gaussian_distribution(f_vector, fsigma_base)

    if display:
        fig = plt.figure(figsize=(15,5))
        fig.add_subplot(111)
        plt.title('Images sequence')
        plt.plot(f_vector, eff_sequence, 'k-', label='eff sequence')
        plt.plot(f_vector, w_eff_sequence, 'g:', label='w eff sequence')
        plt.plot(f_vector, base_sequence, 'r-', label='base level')
        plt.xlim([0., f_vector[-1]])
        plt.ylim([0., eff_sequence.max()])
        plt.legend(loc=1)
        plt.show()
        plt.close(fig)

    return base_sequence, w_eff_sequence, eff_sequence


#-----------------------------------------------------------------------------------------------------------------
def create_video_pngs(config, effects_data, v_start=0, v_end=-1, renew=True):
    """
    Creates the pngs generated by the interpolated frames and the effects defined.

    Frames from ``floor(v_start/config.t_per_frame)`` up to (excluding)
    ``ceil(v_end/config.t_per_frame)`` are written to
    ``<vid_dir>/frames_final``. When ``v_end == -1`` the range ends at
    ``config.n_frames``. Frames before ``config.frames_black_start`` are
    rendered black; the others are loaded, processed with ``effects_data``
    and saved.

    Parameters:
    config: configuration (t_per_frame, n_frames, frames_black_start, ...).
    effects_data: per-frame effect sequences passed to Frame.process_frame.
    v_start: start time of the rendered range.
    v_end: end time of the rendered range, or -1 for "until the last frame".
    renew: if False, frames whose png already exists are skipped.

    Raises:
    ValueError: if v_end is neither -1 nor positive.
    """
    f_start = int(np.floor(v_start/config.t_per_frame))

    if v_end == -1:
        f_end = config.n_frames
    elif v_end > 0.:
        f_end = int(np.ceil(v_end/config.t_per_frame))
    else:
        raise ValueError('v_end must be -1 or positive, got %r' % (v_end,))

    n_frames = f_end - f_start

    print('Producing frames: [%04d - %04d]' % (f_start, f_end))

    ## Make sure the output directory exists (previous frames are kept):
    frames_dir = '%s/frames_final' % vid_dir
    if not os.path.exists(frames_dir):
        os.makedirs(frames_dir)

    print('\nGenerating video frames...\n')

    for f_num in range(f_start, f_end):

        comp = 100.*(f_num - f_start)/(n_frames)
        print('\rFrame: %04d - %.2f completed' % (f_num, comp), end='')

        # Same naming scheme as Frame.png_file (negative frames use 'frame_-NNNN').
        if f_num >= 0:
            png_file = '%s/frames_final/frame_%04d.png' % (vid_dir, f_num)
        else:
            png_file = '%s/frames_final/frame_-%04d.png' % (vid_dir, -f_num)
        if not renew and os.path.exists(png_file):
            continue

        frame = Frame(config, f_num)
        if f_num < config.frames_black_start:
            frame.set_black_frame()
        else:
            frame.load()
            frame.process_frame(config, effects_data)
        frame.save_png(test=False)

    print('\r100.00 completed\n')
    print('*'*20)
    

#-----------------------------------------------------------------------------------------------------------------
class Frame(object):
    """
    Frame object.

    Holds the image data of one video frame and renders it to a png. The
    interpolated data are read from ``<img_dir>/interp_frames`` and the png
    is written to ``<vid_dir>/frames_final``.
    """

    def __init__(self, config, number):

        self.number = number
        # Keep the configuration so later methods do not rely on a global one.
        self.config = config

        self.npy_file = '%s/interp_frames/frame_%04d.npy' % (img_dir, number)
        if number >= 0:
            self.png_file = '%s/frames_final/frame_%04d.png' % (vid_dir, number)
        else:
            self.png_file = '%s/frames_final/frame_-%04d.png' % (vid_dir, -number)

        # Color map of images shown in video ['gray', 'bone', 'afmhot', 'CMRmap']:
        self.img_cmap = 'bone'
        # Shown images in video in log scale? (True recommended):
        self.img_log_scale = True
        # Max value of the color scale of images in video:
        self.img_v_max = 10.
        # Size of the black border frame:
        self.width_border = 10.

        self.figsizex = 1.*config.video_pixx/config.my_dpi
        self.figsizey = 1.*config.video_pixx/config.video_ratio/config.my_dpi

    def set_img_cmap(self, img_cmap):
        self.img_cmap = img_cmap

    def set_img_v_max(self, img_v_max):
        self.img_v_max = img_v_max

    def set_image_data(self, image_data):
        # Work in float64 whatever the stored dtype (builtin float: np.float
        # was removed in NumPy 1.24).
        self.image_data = image_data.astype(float, copy=False)

    def set_black_frame(self):
        # Uniform image of ones: with the color scale starting at v_min = 1
        # it renders at the bottom of the color map.
        image_data = np.ones((int(9*self.config.video_ratio), 9))
        self.set_image_data(image_data)

    def load(self):
        self.set_image_data(np.load(self.npy_file))

    def shape(self):
        return np.shape(self.image_data)

    def save_png(self, test=False):
        """Render image_data to self.png_file (or to the test file when test=True)."""

        fig = plt.figure(figsize=(self.figsizex, self.figsizey), dpi=self.config.my_dpi)
        ax = fig.add_axes([0., 0., 1., 1.])
        ax.set_facecolor("black")
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])

        v_min = 1.

        if self.img_log_scale:
            # vmin/vmax go to the norm itself: recent Matplotlib versions
            # reject vmin/vmax passed to imshow together with a norm instance.
            ax.imshow(self.image_data, cmap=self.img_cmap,
                      norm=LogNorm(vmin=v_min, vmax=self.img_v_max))
        else:
            ax.imshow(self.image_data, cmap=self.img_cmap, vmin=v_min, vmax=self.img_v_max)

        if test:
            test_file = '%s/test_frame.png' % project_dir
            # Save before showing: the figure may be gone once the window closes.
            fig.savefig(test_file, facecolor='black')
            plt.show()
        else:
            fig.savefig(self.png_file, facecolor='black')
        fig.clf()
        plt.close(fig)

    def process_frame(self, config, effects_data):
        """Apply the effect chain for this frame number, then display scaling and border."""

        self.image_data = effect_base(self.image_data)
        self.image_data = effect_fft_inv(self.image_data, effects_data['fft_inv'][self.number], 1.5)
        self.image_data = effect_fft(self.image_data, effects_data['fft'][self.number], 1.5)
        self.image_data = effect_blackwhitehole(self.image_data, effects_data['platos'][self.number], effects_data['bateria'][self.number])
        self.image_data = effect_singularity(self.image_data, effects_data['final'][self.number], 1.8)
        self.image_data = effect_fade_to_black(self.image_data, effects_data['ftb'][self.number])
        self.image_data = display_image(config, self.image_data)
        self.image_data = effect_prepare_border(self.image_data, self.width_border)
    