In [1]:
import matplotlib.pyplot as plt
import h5py
import pickle
from matplotlib.colors import hsv_to_rgb
import numpy as np
from dae import ex
import os.path

In [2]:
# Longer (30-iteration) EM evaluation, used to render per-iteration animations.
# Each dataset name maps to the value passed to the evaluation as 'em.k'
# (presumably the number of groups — confirm against the `dae` config).
datasets = dict(
    bars=12,
    corners=5,
    shapes=3,
    multi_mnist=3,
    mnist_shape=2,
    simple_superpos=2,
)
nr_iters = 30
# Each rendered frame is a grid of nrows x ncols samples.
nrows, ncols = 10, 12



# For every dataset: run the EM evaluation, load the dumped results, and save one
# grid image per EM iteration into animations/<dataset>/img_XX.png.
for ds, k in datasets.items():
    results_filename = 'Results/{}_{}_{}.pickle'.format(ds, nr_iters, k)
    animation_dir = os.path.join('animations', ds)
    if not os.path.exists(animation_dir):
        os.makedirs(animation_dir)

    # Run the `evaluate` command of the `dae` experiment; its results are dumped
    # to `results_filename` via the 'em.dump_results' config entry.
    ex.run_command('evaluate', config_updates={
        'dataset.name': ds,
        'net_filename': 'Networks/best_{}_dae.h5'.format(ds),
        'em.k': k,
        # Use the configured count (was hard-coded to 30), so the run always agrees
        # with the result filename and with the frame loop below.
        'em.nr_iters': nr_iters,
        'em.dump_results': results_filename,
        'em.nr_samples': nrows * ncols,
        'seed': 42})

    # NOTE(review): pickle.load can execute arbitrary code; only load files this notebook wrote.
    with open(results_filename, 'rb') as f:
        scores, likelihoods, results = pickle.load(f)

    # If the last axis is not already RGB, treat each of its entries as one channel
    # (presumably one per group — confirm) and blend fully saturated, evenly spaced
    # hues (starting at hue 2/3, i.e. blue), weighted by the channel values.
    if results.shape[-1] != 3:
        nr_colors = results.shape[-1]
        hsv_colors = np.ones((nr_colors, 3))
        # 2.0 / 3 rather than 2/3 so the offset is never truncated to 0 by integer division.
        hsv_colors[:, 0] = (np.linspace(0, 1, nr_colors, endpoint=False) + 2.0 / 3) % 1.0
        color_conv = hsv_to_rgb(hsv_colors)
        results = results.reshape(-1, nr_colors).dot(color_conv).reshape(results.shape[:-1] + (3,))

    # One frame per EM snapshot; assumes axis 1 of `results` holds nr_iters + 1 entries.
    for it in range(nr_iters + 1):
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows))
        for r in range(nrows):
            for c in range(ncols):
                ax = axes[r, c]
                ax.imshow(results[ncols * r + c, it, 0, :, :, 0, :], interpolation='nearest')
                ax.set_xticks([])
                ax.set_yticks([])
        fig.subplots_adjust(wspace=0, hspace=0)
        fig.savefig(os.path.join(animation_dir, 'img_{:02d}.png'.format(it)),
                    bbox_inches='tight', pad_inches=0, dpi=72.26)
        # Release the figure; otherwise every frame of every dataset stays open in memory.
        plt.close(fig)

INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started


Average Score: 0.6794
Average Confidence: 0.8952
wrote the results to Results/multi_mnist_30_3.pickle

INFO - binding_dae - Result: 0.6794270494870049
INFO - binding_dae - Completed after 0:03:08
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Score: 0.8916

INFO - binding_dae - Result: 0.8916011324539921
INFO - binding_dae - Completed after 0:00:04
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Confidence: 0.9474
wrote the results to Results/simple_superpos_30_2.pickle
Average Score: 0.9537
Average Confidence: 0.9470
wrote the results to Results/shapes_30_3.pickle

INFO - binding_dae - Result: 0.9537316145962019
INFO - binding_dae - Completed after 0:01:00
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Score: 0.5882
Average Confidence: 0.9480
wrote the results to Results/mnist_shape_30_2.pickle

INFO - binding_dae - Result: 0.588232954116256
INFO - binding_dae - Completed after 0:00:50
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Score: 0.9853
Average Confidence: 0.9785
wrote the results to Results/bars_30_12.pickle

INFO - binding_dae - Result: 0.9853389285150709
INFO - binding_dae - Completed after 0:02:33
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Score: 0.8972
Average Confidence: 0.9830
wrote the results to Results/corners_30_5.pickle

INFO - binding_dae - Result: 0.8971806767597252
INFO - binding_dae - Completed after 0:01:31







In [8]:
# Longer (30-iteration) EM evaluation of the *_train_multi networks, used to render
# per-iteration animations. Each dataset name maps to the value passed as 'em.k'.
# simple_superpos (k=2) is left out of this run.
# NOTE(review): presumably because no train_multi network exists for it — confirm.
datasets = dict([
    ('bars', 12),
    ('corners', 5),
    ('shapes', 3),
    ('multi_mnist', 3),
    ('mnist_shape', 2),
])
nr_iters = 30
# Each rendered frame is a grid of nrows x ncols samples.
nrows = 10
ncols = 12



# For every dataset: run the EM evaluation of the *_train_multi network, then save one
# grid image per EM iteration (group colors modulated by the input image) into
# animations/<dataset>_train_multi/img_XX.png.
for ds, k in datasets.items():
    results_filename = 'Results/{}_{}_{}_train_multi.pickle'.format(ds, nr_iters, k)
    animation_dir = os.path.join('animations', '{}_train_multi'.format(ds))
    if not os.path.exists(animation_dir):
        os.makedirs(animation_dir)

    ex.run_command('evaluate', config_updates={
        'dataset.name': ds,
        'net_filename': 'Networks/best_{}_dae_train_multi.h5'.format(ds),
        'em.k': k,
        'em.e_step': 'max',
        'em.nr_iters': nr_iters,
        'em.dump_results': results_filename,
        'em.nr_samples': nrows * ncols,
        'seed': 42})

    nr_shown = nrows * ncols
    # NOTE(review): absolute, user-specific dataset path; consider making it configurable.
    # Open read-only explicitly: older h5py versions default to mode 'a' (read/write/create).
    with h5py.File('/home/greff/Datasets/{}.h5'.format(ds), 'r') as f:
        # Read only what the grid displays: index 0 on the first axis and the first
        # nr_shown entries on the second (the original indexed input_image[0, i]).
        shown_inputs = f['test']['default'][0, :nr_shown]

    # NOTE(review): pickle.load can execute arbitrary code; only load files this notebook wrote.
    with open(results_filename, 'rb') as f:
        scores, likelihoods, results = pickle.load(f)

    # If the last axis is not already RGB, blend fully saturated, evenly spaced hues
    # (starting at hue 2/3, i.e. blue), weighted by the per-channel values.
    if results.shape[-1] != 3:
        nr_colors = results.shape[-1]
        hsv_colors = np.ones((nr_colors, 3))
        # 2.0 / 3 rather than 2/3 so the offset is never truncated to 0 by integer division.
        hsv_colors[:, 0] = (np.linspace(0, 1, nr_colors, endpoint=False) + 2.0 / 3) % 1.0
        color_conv = hsv_to_rgb(hsv_colors)
        results = results.reshape(-1, nr_colors).dot(color_conv).reshape(results.shape[:-1] + (3,))

    # One frame per EM snapshot; assumes axis 1 of `results` holds nr_iters + 1 entries.
    for it in range(nr_iters + 1):
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows))
        for r in range(nrows):
            for c in range(ncols):
                idx = ncols * r + c
                groups = results[idx, it, 0, :, :, 0, :]
                # Affine map x -> 0.7 * x + 0.3 (i.e. [0, 1] -> [0.3, 1] if inputs are in
                # [0, 1]), so empty input regions are dimmed but group colors stay visible.
                # NOTE(review): relies on broadcasting against `groups` (likely H x W x 1) — confirm.
                in_img = shown_inputs[idx] * 0.7 + 0.3
                ax = axes[r, c]
                ax.imshow(groups * in_img, interpolation='nearest')
                ax.set_xticks([])
                ax.set_yticks([])
        fig.subplots_adjust(wspace=0, hspace=0)
        fig.savefig(os.path.join(animation_dir, 'img_{:02d}.png'.format(it)),
                    bbox_inches='tight', pad_inches=0, dpi=72.26)
        # Release the figure; otherwise every frame of every dataset stays open in memory.
        plt.close(fig)

INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started
INFO - binding_dae - Result: 0.29725815533634
INFO - binding_dae - Completed after 0:00:07
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started


Average Score: 0.2973
Average Confidence: 1.0000
wrote the results to Results/mnist_shape_30_2_train_multi.pickle
Average Score: 0.7046
Average Confidence: 1.0000
wrote the results to Results/corners_30_5_train_multi.pickle

INFO - binding_dae - Result: 0.7045697660668432
INFO - binding_dae - Completed after 0:00:10
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Score: 0.8507
Average Confidence: 1.0000
wrote the results to Results/bars_30_12_train_multi.pickle

INFO - binding_dae - Result: 0.8507084656354675
INFO - binding_dae - Completed after 0:00:18
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Score: 0.6322
Average Confidence: 1.0000
wrote the results to Results/multi_mnist_30_3_train_multi.pickle

INFO - binding_dae - Result: 0.6322366970257197
INFO - binding_dae - Completed after 0:00:13
INFO - binding_dae - Running command 'evaluate'
INFO - binding_dae - Started



Average Score: 0.7558
Average Confidence: 1.0000
wrote the results to Results/shapes_30_3_train_multi.pickle

INFO - binding_dae - Result: 0.7558324914261669
INFO - binding_dae - Completed after 0:00:07







In [1]:
import os.path

In [2]:
# Names of the frame directories that sit directly under animations/.
animations_root = 'animations'
subdirs = []
for entry in os.listdir(animations_root):
    if os.path.isdir(os.path.join(animations_root, entry)):
        subdirs.append(entry)

In [3]:
from subprocess import call

import glob
from subprocess import check_call

# Stitch each animations/<d>/*.png frame sequence into animations/<d>.gif with
# ImageMagick's `convert`.
for d in subdirs:
    # Expand and sort the frames in Python instead of passing a literal '*.png' to
    # `convert` (no shell is involved). The zero-padded img_XX names sort in frame order.
    frames = sorted(glob.glob(os.path.join('animations', d, '*.png')))
    if not frames:
        continue  # no frames here; `convert` would fail without inputs
    # ImageMagick: -delay is in 1/100 s ticks (20 -> 0.2 s per frame); -loop 0 repeats forever.
    # check_call raises CalledProcessError, whereas `call` silently ignored failed conversions.
    check_call(['convert', '-delay', '20', '-loop', '0'] + frames
               + [os.path.join('animations', '{}.gif'.format(d))])
