In [4]:
import numpy as np
import h5py
import matplotlib.pyplot as plt
import os
from PIL import Image
from scipy.ndimage.interpolation import rotate
from skimage.transform import resize
 
# Default caption template: the first %s is the element name, the second the movement phrase.
TEMPLATE = "the %s is %s"
# Movement id -> caption phrase. The ids match the keys of the default mvmt_fct_dict in get_trajectory
# (0 right-left, 1 up-down, 2/3 circular orbit, 4/5 in-place rotation).
mvmt2str={0:'moving left and right', 1:'moving up and down', 2:'rotating clockwise', 3:'rotating counterclockwise', 
          4:'rotating inplace clockwise', 5:'rotating inplace counterclockwise'}
# Word -> integer token used to encode captions (see sent2matrix / matrix2sent).
# Token 0 is the blank ' ' entry: sent2matrix zero-fills unused caption positions, so 0 acts as padding,
# which is why the digit '0' is mapped to 21 instead of 0.
# NOTE: create_gif_caption_dataset (via expand_dictionary) adds missing words such as 'rotating',
# 'inplace' or 'clockwise' to this dict *in place* when it is used as the default vocabulary.
word2num_dict = {' ':0, '0':21, '1':1, '2':2, '3':3, '4':4, '5':5, '6':6, '7':7, '8':8, '9':9, 
              'the': 10, 'digit': 11, 'and': 12, 'is':13, 'are':14, 'moving':15, 'up':16, 'down':17, 
              'left':18, 'right': 19, '.':20}
 
def create_gif_caption_dataset(data, labels, num_gifs, moving_elements_per_gif, allowed_elements, allowed_movements, 
                               num_frames=10, gif_w=64, gif_h=64, period=[10,15], save=True, save_memory= True, tol=4, 
                               output_file='./gifcap.h5', template=TEMPLATE, elem2str={}, mvmt2str=mvmt2str, word2num_dict=word2num_dict):
    '''Build a synthetic dataset of GIFs with moving elements plus their encoded captions.

    `num_gifs` GIFs are produced, each containing `moving_elements_per_gif` elements drawn
    from `allowed_elements`, every element following one of the movements listed in
    `allowed_movements`. Elements and movements are balanced across the dataset.

    Args:
        data, labels: element images and their labels (e.g. from load_data_mnist / load_data_icons).
        num_frames, gif_w, gif_h: size of every GIF (frames x height x width).
        period: forwarded to the trajectory generator (a [low, high] range for the motion period).
        save, output_file: when `save` is true the result is written to `output_file` (HDF5).
        save_memory, tol: accepted for API compatibility; not used by this function.
        template: caption template with two %s slots (element name, movement phrase).
        elem2str, mvmt2str: id -> text mappings for elements and movements.
        word2num_dict: vocabulary; words from `elem2str`/`mvmt2str` values that are missing
            are added to it *in place*.

    Returns:
        (gifs, captions) as produced by generate_gifs and generate_captions.
    '''
    # Make sure every element name / movement phrase can be encoded. expand_dictionary mutates
    # `word2num_dict` in place; sent2matrix and save_gifs read the module-level vocabulary, so this
    # in-place update is what they see when the default vocabulary is used.
    vocabulary_sources = [*elem2str.values(), *mvmt2str.values()]
    expand_dictionary(word2num_dict, vocabulary_sources)

    # Balanced (n_gifs x elements_per_gif) label matrices. Elements are drawn before movements,
    # which fixes the order in which the global NumPy RNG is consumed.
    element_ids = get_random_elements(num_gifs, moving_elements_per_gif, allowed_elements, elem2str)
    movement_ids = get_random_movements(num_gifs, moving_elements_per_gif, allowed_movements, mvmt2str)

    caption_tokens = generate_captions(element_ids, movement_ids, template)
    frames = generate_gifs(data, labels, element_ids, movement_ids, num_frames, gif_w, gif_h,
                           elem2str=elem2str, period=period)

    if save:
        save_gifs(output_file, frames, caption_tokens)

    return frames, caption_tokens

def save_gifs(output_file, gifs, captions, word2num_dict=word2num_dict):
    '''Write GIFs, encoded captions and the vocabulary to an HDF5 file.

    Datasets written:
        'gifs'        -- the GIF array as given.
        'captions'    -- the encoded caption matrix as given.
        'dict_keys'   -- vocabulary words as byte strings (dtype 'S').
        'dict_values' -- the token id of each word, aligned with 'dict_keys'.

    Args:
        output_file: path of the HDF5 file; an existing file is overwritten (mode 'w').
        word2num_dict: vocabulary to store. Defaults to the module-level vocabulary, i.e.
            the one create_gif_caption_dataset expands when called with its defaults.
    '''
    with h5py.File(output_file, 'w') as hf:
        hf.create_dataset('gifs', data=gifs)
        hf.create_dataset('captions', data=captions)
        # The dict is stored as two parallel arrays; build_dict_from_keys_and_values re-joins them.
        words = np.array(list(word2num_dict.keys()), dtype='S')
        values = np.array(list(word2num_dict.values()))
        hf.create_dataset('dict_keys', data=words)
        hf.create_dataset('dict_values', data=values)
        
def load_data_mnist(mnist_file_path):
    '''Load MNIST from an HDF5 file and merge its train and test splits.

    Reads the datasets 'train/inputs', 'train/targets', 'test/inputs' and 'test/targets'
    and concatenates train and test along axis 0 (train first).

    Returns:
        data: concatenated inputs.
        labels: concatenated targets, aligned with `data`.
        elem2str: digit label -> its string ('0'..'9'), used when building captions.
    '''
    # Open read-only and close the file when done. h5py versions before 3.0 defaulted to
    # append mode ('a'), which can create or modify the file, and the handle was never closed.
    with h5py.File(mnist_file_path, 'r') as f:
        train_data = np.array(f['train']['inputs'])
        train_labels = np.array(f['train']['targets'])
        val_data = np.array(f['test']['inputs'])
        val_labels = np.array(f['test']['targets'])
    data = np.concatenate((train_data, val_data), axis=0)
    labels = np.concatenate((train_labels, val_labels), axis=0)
    elem2str = {digit: str(digit) for digit in range(10)}
    return data, labels, elem2str
 
def load_data_icons(icon_folder):
    '''Load every icon image in `icon_folder`; labels come from the file names.

    The label of a file is the part of its name before the first '_' (e.g. 'cat_01.png' -> 'cat').
    Each image is converted to PIL mode 'LA' and only its alpha band is kept, as a uint8 array.

    Files are read in sorted name order, because os.listdir order is arbitrary and would make the
    generated dataset irreproducible. Hidden files (e.g. '.DS_Store') and sub-directories are skipped.

    Returns:
        icons: list of 2-D uint8 arrays.
        labels: list of label strings aligned with `icons`.
        elem2str_icons: dict mapping 0..K-1 to the K unique labels, in sorted order.
    '''
    labels = []
    icons = []
    for name in sorted(os.listdir(icon_folder)):
        path = os.path.join(icon_folder, name)
        if name.startswith('.') or not os.path.isfile(path):
            continue
        labels.append(name.split('_')[0])
        with Image.open(path) as image:
            icons.append(np.array(image.convert('LA'), dtype=np.uint8)[:, :, 1])

    # Element id -> label string, ids assigned in sorted label order.
    elem2str_icons = dict(enumerate(np.unique(labels)))
    return icons, labels, elem2str_icons

def get_random_movements(n_gifs, n_elem_per_gif, movement_list, mvmt2str):
    '''Draw a balanced (n_gifs x n_elem_per_gif) matrix of movement ids from `movement_list`.

    Movement names in `movement_list` are translated to ids through `mvmt2str`.
    '''
    movement_matrix = get_random_labels(N=n_gifs, M=n_elem_per_gif, label_list=movement_list, elem2str=mvmt2str)
    return movement_matrix
 
def get_random_elements(n_gifs, n_elem_per_gif, element_list, elem2str):
    '''Draw a balanced (n_gifs x n_elem_per_gif) matrix of element ids from `element_list`.

    Element names in `element_list` are translated to ids through `elem2str`.
    '''
    element_matrix = get_random_labels(N=n_gifs, M=n_elem_per_gif, label_list=element_list, elem2str=elem2str)
    return element_matrix
 
def get_random_labels(N, M, label_list, elem2str):
    '''Return an N x M matrix of labels from `label_list` with balanced label counts.

    Each column is an independent random permutation (global NumPy RNG) of a length-N
    sequence that repeats `label_list` as many whole times as fits and then pads with its
    first entries, so every label appears floor(N/L) or floor(N/L)+1 times per column.

    Args:
        N, M: number of rows (GIFs) and columns (elements per GIF).
        label_list: label ids, or label strings that are mapped to ids via the inverse of `elem2str`.
        elem2str: id -> string mapping (only used when `label_list` holds strings).

    Returns:
        uint8 array of shape (N, M); int64 if any label does not fit in uint8.

    Raises:
        ValueError: if `label_list` is empty.
    '''
    if len(label_list) == 0:
        raise ValueError('label_list must contain at least one label.')

    if isinstance(label_list[0], str):
        str2num = create_reverse_dictionary(elem2str)
        label_list = [str2num[l] for l in label_list]
    label_list = list(label_list)

    n_repeats = N // len(label_list)
    ordered_elems = label_list * n_repeats
    ordered_elems += label_list[:N - len(ordered_elems)]

    # uint8 keeps the matrix compact, but labels outside [0, 255] would silently wrap around
    # on assignment, so fall back to a wide integer type when needed.
    fits_uint8 = 0 <= min(label_list) and max(label_list) <= np.iinfo(np.uint8).max
    elems = np.zeros((N, M), dtype=np.uint8 if fits_uint8 else np.int64)

    for i in range(M):
        elems[:, i] = np.random.permutation(ordered_elems)

    return elems
 
def generate_captions(elems, mvmts, template, elem2str=None, mvmt2str=None):
    '''Build one caption per GIF and return the captions encoded as a token matrix.

    A caption is ``template % (element, movement)`` for the first element, followed by
    `` and `` plus the same pattern for each further element, and ends with `` .``, e.g.
    ``"the digit 7 is moving up and down and the digit 3 is rotating clockwise ."``.

    Args:
        elems: (n_gifs, n_elems) array or nested list of element ids.
        mvmts: (n_gifs, n_elems) array or nested list of movement ids, one per element.
        template: format string with two %s slots (element name, movement phrase).
        elem2str: element id -> name. Defaults to the notebook-level ``elem2str``
            (looked up at call time).
        mvmt2str: movement id -> phrase. Defaults to the notebook-level ``mvmt2str``.

    Returns:
        float array of shape (n_gifs, MAX_CAPTION_LENGTH); row i is caption i encoded by
        sent2matrix (zero-padded).

    Raises:
        ValueError: if `elems` and `mvmts` do not have the same shape.
    '''
    elems = np.asarray(elems)
    mvmts = np.asarray(mvmts)
    if elems.shape != mvmts.shape:
        raise ValueError(f'elems and mvmts must have the same shape, got {elems.shape} and {mvmts.shape}.')

    # The parameters shadow the module-level names, so the fallback goes through globals().
    if elem2str is None:
        elem2str = globals()['elem2str']
    if mvmt2str is None:
        mvmt2str = globals()['mvmt2str']

    caption_matrix = np.zeros((len(elems), MAX_CAPTION_LENGTH))
    for i, (elem_row, mvmt_row) in enumerate(zip(elems, mvmts)):
        clauses = [template % (elem2str[e], mvmt2str[m]) for e, m in zip(elem_row, mvmt_row)]
        caption = ' and '.join(clauses) + ' .'
        caption_matrix[i] = sent2matrix(caption)

    return caption_matrix
 
def expand_dictionary(word2num_dict, word_and_sent_array):
    '''Add every word of `word_and_sent_array` that is missing from `word2num_dict`.

    Each entry may be a single word or a sentence; sentences are split on whitespace,
    exactly as sent2matrix splits captions, so empty "words" are never added. New words get
    consecutive ids starting at ``max(existing ids) + 1`` (or 1 for an empty dict, which
    keeps 0 free for padding).

    The dictionary is modified *in place* and also returned; callers rely on the in-place
    update (e.g. to the module-level vocabulary).
    '''
    next_id = max(word2num_dict.values(), default=0) + 1

    for phrase in word_and_sent_array:
        for word in phrase.split():
            if word not in word2num_dict:
                word2num_dict[word] = next_id
                next_id += 1

    return word2num_dict
 
def sent2matrix(sentence, word2num_dict=word2num_dict, max_len=None):
    '''Encode a whitespace-separated sentence as a zero-padded 1-D array of token ids.

    Args:
        sentence: caption text; split on whitespace.
        word2num_dict: word -> token id mapping (defaults to the module-level vocabulary).
        max_len: length of the output vector; defaults to the notebook-level MAX_CAPTION_LENGTH.

    Returns:
        float array of length `max_len`; positions after the last word stay 0.

    Raises:
        ValueError: if the sentence has more than `max_len` words.
        KeyError: if a word is missing from `word2num_dict`.
    '''
    if max_len is None:
        max_len = MAX_CAPTION_LENGTH
    words = sentence.split()
    if len(words) > max_len:
        # Previously surfaced as a bare IndexError from the assignment below.
        raise ValueError(f'Caption has {len(words)} words but max_len is {max_len} '
                         f'(increase MAX_CAPTION_LENGTH): {sentence!r}')
    encoded = np.zeros(max_len)
    for position, word in enumerate(words):
        encoded[position] = word2num_dict[word]
    return encoded
 
def matrix2sent(matrix, mat2sent_dict):
    '''Decode a 1-D array of token ids back into text.

    Every entry, padding included, is looked up in `mat2sent_dict` and appended after a
    single space, so the result starts with ' '.
    '''
    pieces = [' ' + mat2sent_dict[matrix[k]] for k in range(matrix.shape[0])]
    return ''.join(pieces)
 
def create_reverse_dictionary(dictionary):
    '''Return the inverse mapping value -> key.

    If several keys share a value, the key that comes last in iteration order wins.
    '''
    return {value: key for key, value in dictionary.items()}
             
def generate_gifs(data, labels, elems, mvmts, num_frames, gif_width, gif_height, elem2str, period=None):
    '''Render one GIF per row of `elems` / `mvmts`.

    Args:
        data, labels: element images and their labels.
        elems: (n_gifs, n_elems) array or nested list of element ids.
        mvmts: movement ids with the same shape as `elems` (one movement per element).
        num_frames, gif_width, gif_height: size of each GIF.
        elem2str: element id -> name, forwarded to generate_gif.
        period: forwarded to generate_gif (motion period or [low, high] range).

    Returns:
        uint8 array of shape (n_gifs, num_frames, gif_height, gif_width).

    Raises:
        ValueError: if `elems` and `mvmts` do not have the same shape.
    '''
    elems = np.asarray(elems)
    mvmts = np.asarray(mvmts)
    if elems.shape != mvmts.shape:
        # Adjacent literals instead of a backslash continuation, which used to embed the
        # next line's indentation (a long run of spaces) into the message.
        raise ValueError(f'The shapes of the element array and the movements array do not match '
                         f'({elems.shape} vs {mvmts.shape}). One Movement per Element is needed.')

    num_gifs = len(elems)
    gifs = np.zeros((num_gifs, num_frames, gif_height, gif_width), dtype=np.uint8)
    for i in range(num_gifs):
        gifs[i] = generate_gif(data, labels, elems[i], mvmts[i], gif_width, gif_height, num_frames,
                               elem2str=elem2str, period=period)

    return gifs
 
def _paste_clipped(canvas, patch, x, y):
    '''Composite `patch` onto the 2-D `canvas` (via overlap) with its top-left corner at column x, row y, cropping whatever falls outside the canvas.'''
    canvas_h, canvas_w = canvas.shape
    patch_h, patch_w = patch.shape[0], patch.shape[1]
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + patch_w, canvas_w), min(y + patch_h, canvas_h)
    if x0 >= x1 or y0 >= y1:
        return  # the patch lies entirely outside the canvas
    canvas[y0:y1, x0:x1] = overlap(canvas[y0:y1, x0:x1],
                                   patch[y0 - y:y1 - y, x0 - x:x1 - x])


def generate_gif(data, labels, elems, mvmts, gif_w, gif_h, num_frames, elem2str, period=None):
    '''Render a single GIF containing one moving element per entry of `elems`.

    For each element a random image with that label is picked (get_element_image), a trajectory
    is built for its movement (get_trajectory), trying to avoid the start positions already
    used by previous elements, and the element is max-composited (overlap) onto every frame.
    Parts of an element that fall outside the canvas are cropped.

    Args:
        elems, mvmts: equally long lists/tuples/arrays of element ids and movement ids.
        gif_w, gif_h, num_frames: GIF size.
        elem2str: element id -> name, forwarded to get_element_image.
        period: forwarded to get_trajectory.

    Returns:
        uint8 array of shape (num_frames, gif_h, gif_w).

    Raises:
        ValueError: if `elems`/`mvmts` are not sequences or differ in length.
    '''
    sequence_types = (list, tuple, np.ndarray)
    if not isinstance(elems, sequence_types) or not isinstance(mvmts, sequence_types):
        raise ValueError(f'Invalid type of elems or mvmts array: {type(elems)}, {type(mvmts)}. '
                         'Elems should be an array of length n_elems_per_gif. '
                         'If there is only one element in your gif, make sure that both elems and mvmts have length 1.')
    if len(elems) != len(mvmts):
        raise ValueError(f'elems and mvmts must have the same length, got {len(elems)} and {len(mvmts)}.')

    gif = np.zeros((num_frames, gif_h, gif_w))
    previous_start_positions = []

    # e.g. a 7 moving left and right together with a 4 moving up and down
    for elem, mvmt in zip(elems, mvmts):
        # Random image from all images carrying this label
        elem_image = get_element_image(data, labels, elem, elem2str=elem2str)

        # Avoid starting on top of the elements already placed in this GIF
        trajectory = get_trajectory(elem_image, gif_w, gif_h, mvmt, previous_start_positions, period=period)
        start_x, start_y, _ = trajectory(0)
        previous_start_positions.append([int(start_x), int(start_y)])

        for fr in range(num_frames):
            # x, y are the column and row of the frame image's top-left corner
            x, y, frame_image = trajectory(fr)
            _paste_clipped(gif[fr], frame_image, int(x), int(y))

    return np.array(gif, dtype=np.uint8)
 
def scale_and_rotate(img, scaling=[1,1], rotation=0):
    '''Resize `img` per axis, then rotate it by `rotation` degrees (scipy ``rotate``).

    Args:
        scaling: per-axis factors ``[axis 0 (rows), axis 1 (columns)]``.
        rotation: angle in degrees.
    '''
    # skimage's resize expects (rows, cols); the previous code built (cols*s0, rows*s1),
    # which swapped the axes for non-square images (identical for square ones).
    output_shape = (int(img.shape[0] * scaling[0]), int(img.shape[1] * scaling[1]))
    # preserve_range keeps the original intensity scale; otherwise integer images are
    # rescaled to [0, 1] floats and vanish once the GIF is cast back to uint8.
    img = resize(img, output_shape, preserve_range=True)
    return rotate(img, rotation)
     
def overlap(mat_a, mat_b):
    '''Composite two images pixel-wise by keeping the larger (brighter) value of each pair.'''
    brighter = np.maximum(mat_a, mat_b)
    return brighter
 
def get_element_image(data, labels, elem_label, elem2str, idx = None):
    '''Return one image of element `elem_label` from `data`.

    If `idx` is given, ``data[idx]`` is returned. Otherwise an index is drawn uniformly at
    random (np.random.choice) among the positions whose label equals `elem_label`
    (e.g. for MNIST and label 9, a random 9).

    Args:
        data: indexable collection of images aligned with `labels`.
        labels: label ids, or label strings mapped to ids through the inverse of `elem2str`.
        elem_label: id of the requested element.
        elem2str: element id -> label string.
        idx: optional fixed index into `data`.

    Raises:
        ValueError: if no entry of `labels` matches `elem_label`.
    '''
    if idx is None:
        if isinstance(labels[0], str):
            str2elem = create_reverse_dictionary(elem2str)
            labels = [str2elem[l] for l in labels]
        # np.asarray makes the comparison element-wise: with a plain list (such as the labels
        # returned by load_data_icons) `labels == elem_label` is a single False and no image
        # would ever be found.
        label_indexes = np.where(np.asarray(labels) == elem_label)[0]
        if label_indexes.size == 0:
            raise ValueError(f'No element image found with label {elem_label!r}.')
        idx = np.random.choice(label_indexes)
    return data[idx]
     
def get_start_position(img_width, img_height, gif_width, gif_height, unwanted_positions=[], tol = 5, n_tries=5):
    '''Sample a top-left start position (x, y) for an element, avoiding used positions if possible.

    Up to `n_tries` candidates are drawn with x in [0, gif_width - img_width - 1) and
    y in [0, gif_height - img_height - 1). The first candidate farther than `tol` (Euclidean)
    from every point in `unwanted_positions` is returned; if none qualifies, the last
    candidate is returned anyway.
    '''
    x_high = gif_width - img_width - 1
    y_high = gif_height - img_height - 1

    for _attempt in range(n_tries):
        x = np.random.randint(0, x_high)
        y = np.random.randint(0, y_high)
        candidate = np.array([x, y])
        distances_ok = [np.linalg.norm(np.array(point) - candidate) > tol for point in unwanted_positions]
        if all(distances_ok):
            break

    return x, y

def generate_right_left_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0):
    '''Generate a horizontal back-and-forth (triangle-wave) trajectory with constant speed.

    The element's left edge moves linearly from `pad` to ``gif_w - img_w - pad`` during the
    first half of each period and back during the second half; the row stays at `start_y`.
    The phase is derived from `start_x`: x(0) = pad + travel * start_x / gif_w, with
    travel = gif_w - img_w - 2*pad.

    Returns:
        trajectory(t) -> (x, start_y, img)
    '''
    img_w = img.shape[1]
    half_period = period / 2
    starting_t = half_period * start_x / gif_w
    travel = gif_w - img_w - 2 * pad      # distance covered in each half period
    speed = travel / half_period

    def trajectory(t):
        aux_t = (t + starting_t) % period
        if aux_t <= half_period:
            x = pad + speed * aux_t
        else:
            # Return leg starts from the far end, pad + travel; it previously started from
            # `travel`, which made x jump by `pad` at the turning point.
            x = (pad + travel) - speed * (aux_t - half_period)
        return x, start_y, img

    return trajectory
 
def generate_up_down_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0):
    '''Generate a vertical back-and-forth (triangle-wave) trajectory with constant speed.

    The row of the element's top edge moves linearly from `pad` to ``gif_h - img_h - pad``
    during the first half of each period and back during the second half; the column stays
    at `start_x`. The phase is derived from `start_y`: y(0) = pad + travel * start_y / gif_h,
    with travel = gif_h - img_h - 2*pad.

    Returns:
        trajectory(t) -> (start_x, y, img)
    '''
    img_h = img.shape[0]
    half_period = period / 2
    starting_t = half_period * start_y / gif_h
    travel = gif_h - img_h - 2 * pad      # distance covered in each half period
    speed = travel / half_period

    def trajectory(t):
        aux_t = (t + starting_t) % period
        if aux_t <= half_period:
            y = pad + speed * aux_t
        else:
            # Return leg starts from the far end, pad + travel; it previously started from
            # `travel`, which made y jump by `pad` at the turning point.
            y = (pad + travel) - speed * (aux_t - half_period)
        return start_x, y, img

    return trajectory
 
def generate_circular_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0, clockwise=True):
    '''Generate a circular orbit of the element around the centre of the GIF.

    The start position is expressed in polar coordinates around the centre of the area the
    element's top-left corner may occupy, normalised per axis (by center_x and center_y) so
    that radius 1 is the largest orbit that keeps the whole element on the canvas (an
    ellipse when that area is not square). The start angle is kept. The radius is also kept
    unless it is > 1 (element would leave the canvas) or < 0.15 (barely moving), in which
    case it is resampled uniformly in [0.15, 1) with the global NumPy RNG.

    Args:
        img: element image; axis 0 is its height, axis 1 its width.
        clockwise: True -> the angle decreases with t (clockwise in image coordinates,
            where y grows downwards); False -> counterclockwise.
        pad: accepted for a uniform trajectory signature; not used here.

    Returns:
        trajectory(t) -> (x, y, img); one full turn every `period` time steps.

    Raises:
        ValueError: if the element is not smaller than the GIF on both axes.
    '''
    # shape[0] is the height and shape[1] the width; these were swapped before, which let
    # non-square elements leave the canvas.
    img_h, img_w = img.shape[0], img.shape[1]

    center_x = (gif_w - img_w) / 2
    center_y = (gif_h - img_h) / 2
    if center_x <= 0 or center_y <= 0:
        raise ValueError(f'Element of size {img_w}x{img_h} does not fit in a {gif_w}x{gif_h} GIF.')

    norm_start_x = (start_x - center_x) / center_x
    norm_start_y = (start_y - center_y) / center_y
    start_theta = np.arctan2(norm_start_x, norm_start_y)
    current_r = np.sqrt(norm_start_x**2 + norm_start_y**2)
    min_r = 0.15

    if current_r > 1 or current_r < min_r:
        r = np.random.rand() * (1 - min_r) + min_r
    else:
        r = current_r

    starting_t = period * start_theta / (2 * np.pi)
    freq = 1 / period

    def trajectory(t):
        aux_t = (-t if clockwise else t) + starting_t
        x = center_x * r * (np.sin(aux_t * 2 * np.pi * freq)) + center_x
        y = center_y * r * (np.cos(aux_t * 2 * np.pi * freq)) + center_y
        return x, y, img

    return trajectory
 
def generate_counterclockwise_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0):
    '''Counterclockwise orbit around the GIF centre (see generate_circular_trajectory).'''
    return generate_circular_trajectory(start_x, start_y, img, gif_w, gif_h,
                                        clockwise=False, pad=pad, period=period)
 
def generate_clockwise_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0):
    '''Clockwise orbit around the GIF centre (see generate_circular_trajectory).'''
    return generate_circular_trajectory(start_x, start_y, img, gif_w, gif_h,
                                        clockwise=True, pad=pad, period=period)
 
def generate_inplace_clockwise_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0):
    '''In-place rotation with clockwise=True (see generate_inplace_rotation_trajectory).

    `period` and `pad` are now forwarded; they used to be hard-coded to 10 and 0, so the
    period sampled by get_trajectory was ignored for this movement.
    '''
    return generate_inplace_rotation_trajectory(start_x, start_y, img, gif_w, gif_h,
                                                period=period, pad=pad, clockwise=True)
     
def generate_inplace_counterclockwise_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0):
    '''In-place rotation with clockwise=False (see generate_inplace_rotation_trajectory).

    `period` and `pad` are now forwarded; they used to be hard-coded to 10 and 0, so the
    period sampled by get_trajectory was ignored for this movement.
    '''
    return generate_inplace_rotation_trajectory(start_x, start_y, img, gif_w, gif_h,
                                                period=period, pad=pad, clockwise=False)
 
def generate_inplace_rotation_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0, clockwise=False):
    '''Generate an in-place rotation: the element spins around a fixed centre.

    The centre is the centre of the element placed with its top-left corner at
    (start_x, start_y). If a circle with radius half the image diagonal around that centre
    crosses the canvas border on an axis (on either side), the centre on that axis is moved
    to ``size - half_diagonal - 1``, i.e. close to the right/bottom edge.

    At time t the original image is rotated by ``t * 360 / period`` degrees (negated when
    `clockwise` is True) with scipy's ``rotate``, whose output array is enlarged to hold the
    rotated image; the returned top-left corner keeps that rotated image centred.

    Args:
        img: element image; axis 0 is its height, axis 1 its width.
        pad: accepted for a uniform trajectory signature; not used here.

    Returns:
        trajectory(t) -> (x, y, rotated_img)
    '''
    # shape[0] is the height and shape[1] the width; these were swapped before, which
    # shifted the rotation centre of non-square elements.
    img_h, img_w = img.shape[0], img.shape[1]

    center_x = start_x + img_w / 2
    center_y = start_y + img_h / 2

    img_diagonal = np.sqrt((img_w / 2)**2 + (img_h / 2)**2)

    if center_x + img_diagonal >= gif_w or center_x - img_diagonal < 0:
        center_x = gif_w - img_diagonal - 1

    if center_y + img_diagonal >= gif_h or center_y - img_diagonal < 0:
        center_y = gif_h - img_diagonal - 1

    def trajectory(t):
        signed_t = -t if clockwise else t
        aux_img = rotate(img, signed_t * 360 / period)
        x = center_x - aux_img.shape[1] / 2
        y = center_y - aux_img.shape[0] / 2
        return x, y, aux_img

    return trajectory
     
def generate_inplace_trajectory(start_x, start_y, img, gif_w, gif_h, period=10, pad =0, clockwise=False):
    '''Static "trajectory": every time step yields the start position and the unchanged image.

    All arguments other than start_x, start_y and img are accepted only for a uniform
    trajectory-factory signature.
    '''
    def trajectory(t):
        fixed_x, fixed_y = start_x, start_y
        return fixed_x, fixed_y, img
    return trajectory
 
# deprecated
def update_right_left_sin(x,y,t,w,h,img_w,img_h,pad=1):
    '''Deprecated sinusoidal left-right position update (period 10 time steps).

    Returns (new_x, y): new_x oscillates between `pad` and ``w - img_w - pad``; the incoming
    `x` is ignored and `y` is returned unchanged.
    '''
    amplitude = w - img_w - 2*pad
    new_x = amplitude*(np.sin(t*2*np.pi*(1/10))+1)/2 + pad
    return new_x, y
 
# deprecated
def update_up_down_sin(x,y,t,w,h,img_w,img_h,pad=1):
    '''Deprecated and not a real up-down update.

    NOTE(review): despite its name, this ignores the given `t`, recomputes it as arccos(x)
    (NaN unless -1 <= x <= 1), and returns an updated *x* computed with the width `w`,
    while `y` is returned unchanged.
    '''
    phase_t = np.arccos(x)
    amplitude = w - img_w - 2*pad
    new_x = amplitude*(np.sin(phase_t*2*np.pi*(1/10))+1)/2 + pad
    return new_x, y
     
def get_trajectory(img, gif_w, gif_h, mvmt, unwanted_positions=[], period = None, mvmt_fct_dict=None):
    '''Choose a start position and period for `img` and build the trajectory for movement `mvmt`.

    Args:
        img: element image; shape[1] is used as its width and shape[0] as its height.
        gif_w, gif_h: GIF size.
        mvmt: movement id, a key of `mvmt_fct_dict`.
        unwanted_positions: start positions already used in this GIF; starts within 4 px of
            them are avoided when possible (get_start_position, 5 tries). Not modified.
        period: None -> sampled uniformly in [8.5, 11.5); a [low, high] list/tuple/array ->
            sampled uniformly in [low, high); a number -> used as is.
        mvmt_fct_dict: movement id -> trajectory factory. Defaults to the six built-in
            movements (0 right-left, 1 up-down, 2/3 clockwise/counterclockwise orbit,
            4/5 in-place clockwise/counterclockwise rotation).

    Returns:
        trajectory(t) -> (x, y, image) from the selected factory.

    Raises:
        ValueError: if `mvmt` is not a known movement id.
    '''
    if mvmt_fct_dict is None:
        mvmt_fct_dict = {0:generate_right_left_trajectory, 1:generate_up_down_trajectory,
                         2:generate_clockwise_trajectory, 3:generate_counterclockwise_trajectory,
                         4:generate_inplace_clockwise_trajectory, 5:generate_inplace_counterclockwise_trajectory}

    if mvmt not in mvmt_fct_dict:
        raise ValueError(f'Unknown movement type: {mvmt}')

    # The start position is drawn before the period, which fixes the RNG consumption order.
    start_x, start_y = get_start_position(img.shape[1], img.shape[0], gif_w, gif_h, unwanted_positions, tol = 4, n_tries=5)
    if period is None:
        period_ = 10 + (np.random.rand() - 0.5) * 3
    elif isinstance(period, (list, tuple, np.ndarray)):
        period_ = np.random.uniform(period[0], period[1])
    else:
        # A scalar period used to leave `period_` unbound (UnboundLocalError).
        period_ = period

    return mvmt_fct_dict[mvmt](start_x, start_y, img, gif_w, gif_h, period_)
 
def build_dict_from_keys_and_values(dict_keys, dict_values):
    '''Rebuild a word -> id dict from parallel arrays (ASCII byte-string keys, matching ids).

    This is the inverse of how save_gifs stores the vocabulary ('dict_keys' / 'dict_values').
    '''
    decoded_words = [raw.decode('ascii') for raw in dict_keys]
    vocabulary = {}
    for word, value in zip(decoded_words, dict_values):
        vocabulary[word] = value
    return vocabulary
     

def show_gif_subplots(sequences, captions = None, gif_size=64, n_gifs_to_show=2, start_idx=0, frames_dim = 1, word2num_dict = word2num_dict):
    '''Plot GIFs frame by frame: one matplotlib figure with a row of frames per GIF.

    Args:
        sequences: GIF array; axis `frames_dim` (0 or 1) indexes frames and the other of the
            first two axes indexes GIFs. Each frame is reshaped to (gif_size, gif_size).
        captions: optional encoded captions (one row per GIF), decoded with matrix2sent and
            shown as the figure title.
        n_gifs_to_show: number of consecutive GIFs to plot starting at `start_idx`; clipped to
            the GIFs that actually exist.
        word2num_dict: vocabulary, as a dict or as a sequence of (word, id) pairs.
    '''
    frames = sequences.shape[frames_dim]
    n_gifs = sequences.shape[1-frames_dim]

    if isinstance(word2num_dict, (list, np.ndarray)):
        word2num_dict = {pair[0]: pair[1] for pair in word2num_dict}

    # Clip against the GIFs remaining after start_idx so a non-zero start_idx cannot index
    # past the end of `sequences`.
    n_gifs_to_show = max(0, min(n_gifs_to_show, n_gifs - start_idx))

    has_captions = isinstance(captions, (list, np.ndarray))
    # The reverse vocabulary does not change between GIFs, so build it once.
    num2word = create_reverse_dictionary(word2num_dict) if has_captions else None

    for i in range(start_idx, start_idx + n_gifs_to_show):
        plt.figure(figsize=[16,7])
        if has_captions:
            plt.suptitle(matrix2sent(captions[i], num2word), y=0.66)
        for j in range(frames):
            # GIF i, frame j sits at [i, j] when frames_dim == 1 and at [j, i] when frames_dim == 0.
            f = sequences[i*frames_dim+j*(1-frames_dim), j*frames_dim+i*(1-frames_dim)].reshape(gif_size, gif_size)
            plt.subplot(1, frames, j+1)
            plt.xticks([])
            plt.yticks([])
            plt.imshow(f)
        plt.tight_layout()
        plt.show()

In [None]:
# Configuration for this run
SEED = 0                 # fixes the global NumPy RNG so the generated dataset is reproducible
data_type = 'mnist'      # 'mnist' or 'icons'
elem_per_gif = 2
periods = [8,12]
MAX_CAPTION_LENGTH = 18  # read as a global by generate_captions / sent2matrix

np.random.seed(SEED)

if data_type == 'mnist':
    data, labels, elem2str = load_data_mnist('./mnist_full.h5')
    template = 'the digit %s is %s'
elif data_type == 'icons':
    data, labels, elem2str = load_data_icons('./icons')
    template = 'the %s is %s'
else:
    raise ValueError(f"Unknown data_type {data_type!r}; expected 'mnist' or 'icons'.")

dataset_name = f'gifcap_{str(data_type)}_{str(elem_per_gif)}elems_periods[{periods[0]},{periods[1]}]_caplen{MAX_CAPTION_LENGTH}.h5'
# Pass `periods` itself (not a separate literal) so the file name always matches the periods used.
# NOTE: create_gif_caption_dataset currently ignores `tol`.
gifs, captions = create_gif_caption_dataset(data, labels, num_gifs=10000, moving_elements_per_gif=elem_per_gif,
                                            allowed_elements=list(elem2str.values()), allowed_movements=[2,3,4,5], num_frames=10,
                                            gif_w=64, gif_h=64, period=periods, tol=6, save=True, template=template, elem2str=elem2str, output_file=dataset_name)

# Reload the saved file to check that the dataset round-trips
with h5py.File(dataset_name,'r') as hf:
    gifs = np.array(hf.get('gifs'))
    captions = np.array(hf.get('captions'))
    dict_keys = np.array(hf.get('dict_keys'))
    dict_values = np.array(hf.get('dict_values'))

show_gif_subplots(gifs, captions, n_gifs_to_show=10, word2num_dict = build_dict_from_keys_and_values(dict_keys, dict_values))