In [1]:
from PIL import Image
import sys
import os
import math
import numpy as np

###########################################################################################
# script to generate moving mnist video dataset (frame by frame) as described in
# [1] arXiv:1502.04681 - Unsupervised Learning of Video Representations Using LSTMs
#     Srivastava et al
# by Tencia Lee
# saves in npy (default), hdf5, npz, or jpg (individual frames) format
###########################################################################################

# helper functions
def arr_from_img(im, shift=0):
    """Convert a PIL image into a float32 array laid out as (channels, width, height).

    Pixel values are divided by 255 and then `shift` is subtracted, so with
    shift=0 an 8-bit image maps to [0, 1].

    Args:
        im: PIL image; assumed to have 8-bit bands (values 0..255), e.g. mode 'L' or 'RGB'.
        shift: value subtracted after scaling.

    Returns:
        np.float32 array of shape (channels, width, height).
    """
    w, h = im.size
    # The channel count is the number of bands in the image mode. (The size of the
    # object returned by getdata() is (w, h), so it cannot reveal the band count.)
    c = len(im.getbands())
    pixels = np.asarray(im.getdata(), dtype=np.float32)
    # getdata() yields pixels in row-major order, i.e. (h, w, c); transpose to (c, w, h).
    return pixels.reshape((h, w, c)).transpose(2, 1, 0) / 255. - shift

def get_picture_array(X, index, shift=0):
    """Turn sample `index` of a (N, channels, width, height) array into a uint8 image array.

    The sample has `shift` added back, is scaled by 255 and clipped to [0, 255].
    Single-channel samples come back as (height, width); multi-channel samples
    as (height, width, channels), i.e. the row-major layout Image.fromarray expects.
    """
    channels = X.shape[1]
    width = X.shape[2]
    height = X.shape[3]

    sample = (X[index] + shift) * 255.
    # (channels, width, height) -> (height, width, channels)
    image = sample.reshape(channels, width, height).transpose(2, 1, 0)
    image = image.clip(0, 255).astype(np.uint8)

    # Drop the channel axis for grayscale so PIL treats it as mode 'L'.
    return image.reshape(height, width) if channels == 1 else image

# generates and returns video frames (uint8 array) together with the per-sequence digit labels
def generate_moving_mnist(shape=(64,64), seq_len=20, seqs=100, num_sz=28, nums_per_image=3):
    """Generate a Moving MNIST video dataset.

    Each sequence shows `nums_per_image` randomly chosen MNIST digits, resized to
    `num_sz` x `num_sz`, moving with a random direction and speed and bouncing
    off the frame edges. Overlapping digits are summed and saturate at 255.

    Args:
        shape: (width, height) of every frame in pixels.
        seq_len: number of frames per sequence.
        seqs: number of sequences to generate.
        num_sz: side length the digits are resized to before pasting.
        nums_per_image: number of digits per sequence.

    Returns:
        dataset: uint8 array of shape (seq_len*seqs, 1, width, height); the frames
            of sequence k are rows k*seq_len .. (k+1)*seq_len - 1.
        mnist_labels: list with one entry per sequence, each a list of the
            (one-hot) labels of the digits used in that sequence.
    """
    # load mnist data (TensorFlow 1.x tutorial loader; images arrive flattened and
    # are reshaped to 28x28 below)
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    # NOTE(review): the tutorial train split is presumably 55000 images, so a batch of
    # 60000 likely wraps into a second shuffled epoch and repeats some digits — confirm.
    mnist, y = mnist.train.next_batch(60000)
    mnist = np.reshape(mnist, newshape=[-1, 1, 28, 28])
    mnist_labels = []

    width, height = shape
    lims = (x_lim, y_lim) = width - num_sz, height - num_sz
    dataset = np.empty((seq_len * seqs, 1, width, height), dtype=np.uint8)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS names the same filter.
    resample = Image.LANCZOS if hasattr(Image, 'LANCZOS') else Image.ANTIALIAS

    for seq_idx in range(seqs):
        if seq_idx % 1000 == 0:
            print(seq_idx)

        # random direction in [-pi, pi) and integer speed in [2, 6] for every digit
        direcs = np.pi * (np.random.rand(nums_per_image) * 2 - 1)
        speeds = np.random.randint(5, size=nums_per_image) + 2
        veloc = [(v * math.cos(d), v * math.sin(d)) for d, v in zip(direcs, speeds)]

        # pick random mnist digits (with replacement) and remember their labels
        picks = list(np.random.randint(0, mnist.shape[0], nums_per_image))
        mnist_images = [Image.fromarray(get_picture_array(mnist, r, shift=0)).resize((num_sz, num_sz), resample)
                        for r in picks]
        mnist_labels.append([y[r] for r in picks])

        # top-left corner of each digit; starts with the digit fully inside the frame
        positions = [(np.random.rand() * x_lim, np.random.rand() * y_lim) for _ in range(nums_per_image)]
        for frame_idx in range(seq_len):
            # draw each digit on its own layer and add the layers, so overlaps sum
            canvas = np.zeros((1, width, height), dtype=np.float32)
            for img, pos in zip(mnist_images, positions):
                layer = Image.new('L', (width, height))
                layer.paste(img, box=tuple(int(round(p)) for p in pos))
                canvas += arr_from_img(layer, shift=0)

            # bounce: flip a velocity component when the next step would leave
            # [-2, lim + 2] along that axis (2 pixels of slack past the edge)
            veloc = [tuple(-vc if (pc + vc < -2 or pc + vc > lim + 2) else vc
                           for pc, vc, lim in zip(p, v, lims))
                     for p, v in zip(positions, veloc)]
            positions = [tuple(pc + vc for pc, vc in zip(p, v)) for p, v in zip(positions, veloc)]

            # Clip BEFORE casting: overlapping digits can sum past 1.0, and a float
            # outside [0, 255] cast to uint8 wraps around instead of saturating.
            dataset[seq_idx * seq_len + frame_idx] = np.clip(canvas * 255, 0, 255).astype(np.uint8)
    return dataset, mnist_labels

def main(dest, filetype='npy', frame_size=64, seq_len=20, seqs=10000, num_sz=28, nums_per_image=2):
    """Generate a Moving MNIST dataset and write it to disk.

    Array formats store the frames as a uint8 array of shape
    (seqs, seq_len, frame_size, frame_size, 1); 'jpg' writes one image per frame.

    Args:
        dest: output path prefix for 'npy' / 'hdf5' / 'npz', or the output
            directory for 'jpg' (created if missing).
        filetype: 'npy' (default), 'hdf5', 'npz' or 'jpg'.
        frame_size: side length of each square frame in pixels.
        seq_len: number of frames per sequence.
        seqs: number of sequences.
        num_sz: side length the digits are resized to.
        nums_per_image: number of moving digits per sequence.

    Raises:
        ValueError: for an unsupported `filetype`; checked before any data is generated.
    """
    if filetype not in ('npy', 'hdf5', 'npz', 'jpg'):
        raise ValueError('unsupported filetype: {!r}'.format(filetype))

    dat, labels = generate_moving_mnist(shape=(frame_size, frame_size), seq_len=seq_len, seqs=seqs,
                                        num_sz=num_sz, nums_per_image=nums_per_image)

    # (seqs*seq_len, 1, w, h) -> (seqs, seq_len, w, h, 1). The channel axis has size 1,
    # so moving it to the end does not reorder any pixel data.
    dat = np.reshape(dat, (seqs, seq_len, frame_size, frame_size, 1))
    print(np.shape(dat))

    if filetype == 'hdf5':
        import h5py
        from fuel.datasets.hdf5 import H5PYDataset

        def save_hd5py(dataset, destfile, indices_dict):
            # `with` guarantees the file is closed even if writing fails
            with h5py.File(destfile, mode='w') as f:
                images = f.create_dataset('images', dataset.shape, dtype='uint8')
                images[...] = dataset
                split_dict = dict((k, {'images': v}) for k, v in indices_dict.items())
                f.attrs['split'] = H5PYDataset.create_split_array(split_dict)
                f.flush()

        # Split indices refer to axis 0 of the saved array, which holds sequences
        # (not frames) after the reshape above: 90% train / 10% test.
        n = dat.shape[0]
        indices_dict = {'train': (0, n * 9 // 10), 'test': (n * 9 // 10, n)}
        # save image
        save_hd5py(dat, dest + '_images.h5', indices_dict)
        # save labels
        np.save(dest + '_labels', labels)

    elif filetype == 'npz':
        # positional array stays 'arr_0' for existing readers; labels are added alongside
        np.savez(dest, dat, labels=labels)

    elif filetype == 'jpg':
        if not os.path.isdir(dest):
            os.makedirs(dest)
        # one grayscale image per frame, numbered consecutively across sequences
        frames = dat.reshape(seqs * seq_len, frame_size, frame_size)
        for i in range(frames.shape[0]):
            # frames are stored (width, height); PIL expects rows first, i.e. (height, width)
            frame = np.ascontiguousarray(frames[i].T)
            Image.fromarray(frame).save(os.path.join(dest, '{}.jpg'.format(i)))

    elif filetype == 'npy':
        np.save(dest + '-images', dat)
        np.save(dest + '-labels', labels)

# Generate the default dataset (10000 sequences x 20 frames, 2 digits each) and
# save it as mnist-hope-images.npy / mnist-hope-labels.npy.
main(dest='mnist-hope')
print('done')

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
(10000, 20, 64, 64, 1)
done


In [None]:
data = np.load( "./datasets/moving_mnist_100_2digits_images.npy" )
print np.shape(data)
data = np.reshape(data, newshape=(10000,20,64,64))
print np.shape(data)

import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(data[0,15,:,:], cmap='gray')

np.save("./datasets/moving_mnist_100_2digits_images.npy", data)

In [None]:
# Load the per-sequence digit labels and show the ones stored for sequence 2.
label = np.load("./datasets/moving_mnist_100_2digits_labels.npy")
print(label[2])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

data = np.load( "./datasets/moving_mnist_60000_2digits_images.npy")
#plt.imshow(data[170,15,:,:], cmap='gray')

label = np.load( "./datasets/moving_mnist_60000_2digits_labels.npy" )
label = np.reshape(label, newshape=(60000,2,10))

x = np.argmax(label[:,0,:], axis=1)
y = np.argmax(label[:,1,:], axis=1)

plt.hist2d(x, y, bins=10)
plt.colorbar()