In [1]:
import matplotlib.pyplot as plt
import h5py
import pickle
from matplotlib.colors import hsv_to_rgb
import numpy as np
from dae import ex
import os.path

In [None]:
# Re-run the evaluation with a longer EM schedule (nr_iters below) so that
# every iteration can be rendered as one animation frame.
#
# Each dataset name maps to the number of groups handed to EM as 'em.k'.
datasets = dict(
    bars=12,
    corners=5,
    shapes=3,
    multi_mnist=3,
    mnist_shape=2,
    simple_superpos=2,
)

# EM iterations per run; the rendering loop writes nr_iters + 1 frames.
nr_iters = 30

# Each frame is a grid of nrows x ncols test samples.
nrows, ncols = 10, 12



# Directory containing the '<dataset>.h5' files. Set the DATASET_DIR
# environment variable to point elsewhere instead of editing this cell.
dataset_dir = os.environ.get('DATASET_DIR', '/home/greff/Datasets')


def _groups_to_rgb(group_maps):
    """Colourise an array whose last axis holds one value per group.

    Every group gets a fully saturated, full-value hue. Hues are evenly
    spaced around the colour wheel and offset by 2/3, so group 0 is blue.
    Each pixel becomes the weighted sum of those colours.

    Returns an array with the same leading shape and a trailing axis of 3.
    """
    nr_colors = group_maps.shape[-1]
    hsv_colors = np.ones((nr_colors, 3))
    # 2.0 keeps this a float division under Python 2 as well.
    hsv_colors[:, 0] = (np.linspace(0, 1, nr_colors, endpoint=False) + 2.0 / 3) % 1.0
    color_conv = hsv_to_rgb(hsv_colors)
    flat_rgb = group_maps.reshape(-1, nr_colors).dot(color_conv)
    return flat_rgb.reshape(group_maps.shape[:-1] + (3,))


def _save_frame(frames, it, path, nrows, ncols):
    """Render iteration `it` of the first nrows*ncols samples as one PNG grid.

    `frames` is indexed as frames[sample, it, 0, :, :, 0, :]. The trailing
    axis is assumed to be RGB (TODO confirm the meaning of the two fixed
    0-indices against the dumped results layout).
    """
    # squeeze=False keeps `axes` 2-D even for a 1-row or 1-column grid.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows),
                             squeeze=False)
    try:
        for r in range(nrows):
            for c in range(ncols):
                ax = axes[r, c]
                ax.imshow(frames[ncols * r + c, it, 0, :, :, 0, :],
                          interpolation='nearest')
                ax.set_xticks([])
                ax.set_yticks([])
        fig.subplots_adjust(wspace=0, hspace=0)
        fig.savefig(path, bbox_inches='tight', pad_inches=0, dpi=72.26)
    finally:
        # Without this every frame stays registered with pyplot; over all
        # datasets that is hundreds of live figures.
        plt.close(fig)


for ds, k in datasets.items():
    results_filename = 'Results/{}_{}_{}.pickle'.format(ds, nr_iters, k)
    animation_dir = 'animations/{}'.format(ds)
    if not os.path.exists(animation_dir):
        os.makedirs(animation_dir)

    ex.run_command('evaluate', config_updates={
               'dataset.name': ds,
               'net_filename': 'Networks/best_{}_dae.h5'.format(ds),
               'em.k': k,
               # Use nr_iters, not a literal, so the EM run, the results
               # filename and the frame loop below always agree.
               'em.nr_iters': nr_iters,
               'em.dump_results': results_filename,
               'seed': 42})

    # Explicit read-only mode: h5py < 3.0 defaults to append mode.
    with h5py.File(os.path.join(dataset_dir, '{}.h5'.format(ds)), 'r') as f:
        true_groups = f['test']['groups'][:]  # not used further in this cell
    # Trusted input: this pickle was just written by the evaluate run above.
    with open(results_filename, 'rb') as f:
        scores, likelihoods, results = pickle.load(f)

    # Exactly 3 channels are shown as-is (imshow treats them as RGB);
    # any other channel count is mapped onto distinct hues.
    if results.shape[-1] != 3:
        results = _groups_to_rgb(results)

    nr_shown = nrows * ncols
    assert results.shape[0] >= nr_shown, (
        '{}: need {} samples for a {}x{} grid, got {}'.format(
            ds, nr_shown, nrows, ncols, results.shape[0]))
    assert results.shape[1] >= nr_iters + 1, (
        '{}: need {} iterations in the dumped results, got {}'.format(
            ds, nr_iters + 1, results.shape[1]))

    for it in range(nr_iters + 1):
        frame_path = os.path.join(animation_dir, 'img_{:02d}.png'.format(it))
        _save_frame(results, it, frame_path, nrows, ncols)

In [6]:
from subprocess import call

# Stitch each dataset's frames into a looping GIF with ImageMagick's `convert`.
# The frame list is built explicitly, in iteration order, from this run's
# nr_iters + 1 frames. A literal '*.png' pattern would also pick up stale
# frames from an earlier, longer run. check_call raises if `convert` is
# missing or fails, instead of silently returning a non-zero code.
from subprocess import check_call

for ds in datasets:
    frame_paths = ['animations/{}/img_{:02d}.png'.format(ds, it)
                   for it in range(nr_iters + 1)]
    check_call(['convert', '-delay', '10', '-loop', '0']
               + frame_paths
               + ['animations/{}.gif'.format(ds)])
