In [None]:
from PIL import Image
import sys
import os
import math
import numpy as np
#import imageio as io
import scipy.io

###########################################################################################
# script to generate moving mnist video dataset (frame by frame) as described in
# [1] arXiv:1502.04681 - Unsupervised Learning of Video Representations Using LSTMs
#     Srivastava et al
# by Tencia Lee
# saves in hdf5, npz, npy (default), or jpg (individual frames) format
###########################################################################################

# number of colour channels per frame; frames are rendered on PIL 'RGB' canvases
channels = 3

# helper functions
def arr_from_img(im, shift=0):
    """Convert a PIL image into a float32 array laid out as (channels, width, height).

    Pixel values are scaled from [0, 255] to [0, 1], then ``shift`` is
    subtracted. The channel count is derived from the pixel data, so
    single-band images produce a leading axis of size 1 instead of failing.

    Args:
        im: PIL image (anything exposing ``size`` and ``getdata()``).
        shift: value subtracted from every scaled pixel.

    Returns:
        np.ndarray, dtype float32, shape (channels, width, height).
    """
    w, h = im.size
    # getdata() gives one tuple per pixel for multi-band images (-> (w*h, c))
    # and one scalar per pixel for single-band images (-> (w*h,)).
    pixels = np.asarray(im.getdata(), dtype=np.float32)
    bands = pixels.shape[1] if pixels.ndim > 1 else 1
    # (h, w, c) row-major pixel grid -> (c, w, h), the layout used for dataset frames
    return pixels.reshape((h, w, bands)).transpose(2, 1, 0) / 255. - shift

def get_picture_array(X, index, shift=0):
    """Render item ``index`` of a (N, channels, width, height) array as a uint8 image.

    The item is offset by ``shift``, scaled by 255, reordered to
    (height, width, channels) and clipped to [0, 255]. A single-channel
    item is returned as a 2-D (height, width) array.
    """
    n_ch, width, height = X.shape[1:4]
    scaled = (X[index] + shift) * 255.
    pixels = scaled.reshape(n_ch, width, height).transpose(2, 1, 0)
    img = np.clip(pixels, 0, 255).astype(np.uint8)
    return img.reshape(height, width) if n_ch == 1 else img

# generates and returns video frames in uint8 array
def _binary_mask(img):
    """Build a mode-"1" PIL paste mask from ``img``: every channel value > 0 is saturated to 255 first."""
    arr = np.array(img).astype(np.float32)
    arr[arr > 0.0] = 255
    return Image.fromarray(arr.astype('uint8')).convert("1")


def _bounce_step(positions, veloc, lims):
    """Advance every sprite by one frame, reflecting off the canvas walls.

    A velocity component is negated when the tentative next coordinate
    (position + current velocity) would leave ``[-2, lim + 2]`` on that axis;
    the position is then advanced with the (possibly reflected) velocity.

    Returns:
        (new_positions, new_veloc), both lists of tuples.
    """
    new_veloc = []
    for pos, vel in zip(positions, veloc):
        reflected = []
        for coord, speed, lim in zip(pos, vel, lims):
            nxt = coord + speed
            reflected.append(-speed if (nxt < -2 or nxt > lim + 2) else speed)
        new_veloc.append(tuple(reflected))
    new_positions = [tuple(c + s for c, s in zip(pos, vel))
                     for pos, vel in zip(positions, new_veloc)]
    return new_positions, new_veloc


def generate_moving_mnist(shape=(64,64), seq_len=20, seqs=100, num_sz=28, nums_per_image=3):
    """Generate video sequences of sprites bouncing around a black canvas.

    Sprites are read from ``./shapes.mat`` (key ``'shapes'``), resized to
    ``num_sz`` x ``num_sz`` and pasted onto a black 'RGB' canvas every frame.

    Args:
        shape: (width, height) of every frame.
        seq_len: frames per sequence.
        seqs: number of sequences.
        num_sz: side length, in pixels, that sprites are resized to.
        nums_per_image: sprites per sequence (drawn with replacement).

    Returns:
        (dataset, labels). ``dataset`` is uint8 with shape
        (seq_len * seqs, channels, width, height); sequence ``s`` occupies
        rows ``s*seq_len`` .. ``(s+1)*seq_len - 1``. ``labels`` holds one list
        per sequence with ``nums_per_image`` rows of the MNIST label array
        (see NOTE below).
    """
    # Only the MNIST labels are used: the pixel data is replaced by shapes.mat.
    from tensorflow.examples.tutorials.mnist import input_data
    mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)
    _, y = mnist_data.train.next_batch(60000)
    mnist_labels = []

    # NOTE(review): labels are looked up in the MNIST ``y`` array by *sprite*
    # index, so they do not describe the shapes actually drawn. TODO confirm
    # whether shapes.mat carries its own labels.
    shapes = scipy.io.loadmat('./shapes.mat')['shapes'].astype(np.float32)
    # presumably (N, rows, cols, channels) on disk -> (N, channels, ., .) as
    # get_picture_array expects -- TODO confirm the .mat layout
    sprites = np.transpose(shapes, axes=[0, 3, 1, 2])

    width, height = shape
    lims = (x_lim, y_lim) = width - num_sz, height - num_sz
    dataset = np.empty((seq_len * seqs, channels, width, height), dtype=np.uint8)

    for seq_idx in range(seqs):
        if seq_idx % 1000 == 0:
            print(seq_idx)

        # random direction in [-pi, pi) and integer speed in [2, 6] per sprite
        direcs = np.pi * (np.random.rand(nums_per_image) * 2 - 1)
        speeds = np.random.randint(5, size=nums_per_image) + 2
        veloc = [(v * math.cos(d), v * math.sin(d)) for d, v in zip(direcs, speeds)]

        # select random sprites; LANCZOS is the filter formerly aliased as
        # ANTIALIAS (that alias was removed in Pillow 10)
        picks = np.random.randint(0, sprites.shape[0], nums_per_image)
        sprite_images = [Image.fromarray(get_picture_array(sprites, r, shift=0))
                         .resize((num_sz, num_sz), Image.LANCZOS) for r in picks]
        mnist_labels.append([y[r] for r in picks])
        # masks depend only on the sprite, so build them once per sequence
        masks = [_binary_mask(img) for img in sprite_images]

        positions = [(np.random.rand() * x_lim, np.random.rand() * y_lim)
                     for _ in range(nums_per_image)]
        for frame_idx in range(seq_len):
            # all sprites share one canvas; later sprites overwrite earlier ones
            canv = Image.new('RGB', (width, height))
            for img, mask, pos in zip(sprite_images, masks, positions):
                canv.paste(img, box=tuple(int(round(p)) for p in pos), mask=mask)
            canvas = arr_from_img(canv, shift=0)

            positions, veloc = _bounce_step(positions, veloc, lims)

            # canvas holds x/255 in float32, so x/255*255 can land just below
            # an integer; round before the cast instead of truncating, and clip
            # before (not after) converting to uint8.
            dataset[seq_idx * seq_len + frame_idx] = np.clip(np.rint(canvas * 255), 0, 255).astype(np.uint8)
    return dataset, mnist_labels

def main(dest, filetype='npy', frame_size=64, seq_len=200, seqs=100, num_sz=28, nums_per_image=2):
    """Generate a moving-shapes dataset and write it to disk.

    Frames are regrouped to (seqs, seq_len, frame_size, frame_size, channels)
    uint8 before saving.

    Args:
        dest: output path prefix ('hdf5', 'npz', 'npy') or output directory ('jpg').
        filetype: 'hdf5' (images + separate labels .npy), 'npz', 'jpg'
            (one file per frame) or 'npy'; any other value saves nothing.
        frame_size: width and height of each square frame.
        seq_len, seqs, num_sz, nums_per_image: forwarded to generate_moving_mnist.
    """
    dat, labels = generate_moving_mnist(shape=(frame_size, frame_size), seq_len=seq_len, seqs=seqs,
                                        num_sz=num_sz, nums_per_image=nums_per_image)

    # (seqs*seq_len, channels, w, h) -> (seqs, seq_len, w, h, channels).
    # Use frame_size rather than a hard-coded 64 so other sizes reshape correctly.
    dat = dat.reshape(seqs, seq_len, channels, frame_size, frame_size)
    dat = np.transpose(dat, axes=[0, 1, 3, 4, 2])

    if filetype == 'hdf5':
        import h5py
        from fuel.datasets.hdf5 import H5PYDataset

        def save_hd5py(dataset, destfile, indices_dict):
            with h5py.File(destfile, mode='w') as f:
                images = f.create_dataset('images', dataset.shape, dtype='uint8')
                images[...] = dataset
                split_dict = dict((k, {'images': v}) for k, v in indices_dict.items())
                f.attrs['split'] = H5PYDataset.create_split_array(split_dict)

        # Split bounds index the first axis of 'images', which holds whole
        # sequences after the reshape above, so split by sequence count
        # (matching len(labels)), not by frame count. '//' keeps them ints.
        n = dat.shape[0]
        split = n * 9 // 10
        indices_dict = {'train': (0, split), 'test': (split, n)}
        # save images
        save_hd5py(dat, dest + '_images.h5', indices_dict)
        # save labels
        np.save(dest + '_labels', labels)

    elif filetype == 'npz':
        np.savez(dest, dat)
    elif filetype == 'jpg':
        if not os.path.isdir(dest):
            os.makedirs(dest)
        # dat is already uint8 channels-last: flatten to individual frames and
        # swap (width, height) back to PIL's (rows=height, cols=width) order.
        frames = dat.reshape((-1,) + dat.shape[2:])
        for i, frame in enumerate(frames):
            pixels = np.ascontiguousarray(frame.swapaxes(0, 1))
            if pixels.shape[-1] == 1:
                pixels = pixels[..., 0]
            Image.fromarray(pixels).save(os.path.join(dest, '{}.jpg'.format(i)))
    elif filetype == 'npy':
        np.save(dest + '-images', dat)

# training and validation set
main(dest='moving-shapes-2-tr', seq_len=20, seqs=11000)

# test sets: 100 sequences of 200 frames, with 1, 2 and 3 moving shapes
for _test_dest, _shape_count in (('moving-shapes-1-te', 1),
                                 ('moving-shapes-2-te', 2),
                                 ('moving-shapes-3-te', 3)):
    main(dest=_test_dest, seq_len=200, seqs=100, nums_per_image=_shape_count)
del _test_dest, _shape_count