In [None]:
import pickle as pkl
import numpy as np
import glob

import mnist_input
import mnist_utils
import utils

import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.backends.backend_pdf import PdfPages

In [None]:
class Hparams:
    """Bare attribute bag for hyperparameters.

    The class declares nothing itself; callers attach whatever fields they
    need after construction (e.g. ``hparams.input_type = 'full-input'``).
    """

In [None]:
# Hyperparameters shared by every figure below: full-input MNIST, 28x28x1 images.
hparams = Hparams()
vars(hparams).update(
    input_type='full-input',
    num_input_images=30,
    image_matrix=0,
    image_shape=(28, 28, 1),
)
# Flattened image length: 28 * 28 * 1 = 784.
hparams.n_input = np.prod(hparams.image_shape)

# Original images; view() below indexes this by image number.
xs_dict = mnist_input.model_input(hparams)

In [None]:
# Display window: image numbers in the half-open range [start, stop).
start, stop = 10, 20
assert 0 <= start
assert hparams.num_input_images >= stop
images_nums = [image_num for image_num in range(start, stop)]

In [None]:
def view(patterns, images_nums, hparams, **kws):
    """Plot original images together with saved reconstructions, one row per model.

    Args:
        patterns: One filename pattern per entry of ``hparams.model_types``,
            in the same order. Each pattern contains a ``{0}`` placeholder
            that is filled with an image number to locate that image's PNG.
        images_nums: Image numbers to display. They are also the keys used to
            look up originals in the module-level ``xs_dict``.
        hparams: Hyperparameter container. ``hparams.model_types`` names the
            reconstruction rows; the whole object is forwarded to
            ``utils.image_matrix``.
        **kws: Extra keyword arguments forwarded to ``utils.image_matrix``
            (e.g. ``alg_labels``).

    Raises:
        ValueError: if ``patterns`` and ``hparams.model_types`` differ in
            length (``zip`` would otherwise silently drop models).
    """
    # Materialize both inputs: each is iterated more than once below.
    patterns = list(patterns)
    images_nums = list(images_nums)
    if len(patterns) != len(hparams.model_types):
        raise ValueError(
            'got {} patterns for {} model types'.format(
                len(patterns), len(hparams.model_types)))

    x_hats_dict = {}
    for model_type, pattern in zip(hparams.model_types, patterns):
        # Key each reconstruction by its image number so it uses the same
        # keys as the originals in xs_dict_temp below.
        x_hats_dict[model_type] = {i: plt.imread(pattern.format(i)) for i in images_nums}
    xs_dict_temp = {i: xs_dict[i] for i in images_nums}
    utils.image_matrix(xs_dict_temp, x_hats_dict, mnist_utils.view_image, hparams, **kws)

# Reconstructions

In [None]:
hparams.measurement_type = 'gaussian'
hparams.model_types = ['Lasso', 'VAE']
is_save = True

# Lasso vs. VAE from Gaussian measurements at noise level 0.1:
# one figure (and one PDF) per number of measurements.
for num_measurements in [10, 25, 50, 100, 200, 300, 400, 500]:
    measurement_dir = '../estimated/mnist/full-input/gaussian/0.1/' + str(num_measurements)
    patterns = [
        measurement_dir + '/lasso/0.1/{0}.png',
        measurement_dir + '/vae/0.0_1.0_0.1_adam_0.01_0.9_False_1000_10/{0}.png',
    ]
    view(patterns, images_nums, hparams, alg_labels=False)

    save_path = '../results/mnist_reconstr_{}_orig_lasso_vae.pdf'.format(num_measurements)
    utils.save_plot(is_save, save_path)

# Super-resolution

In [None]:
# VAE reconstructions from super-resolution (factor 2) measurements.
hparams.measurement_type = 'superres'
hparams.superres_factor = 2
hparams.model_types = ['VAE']
is_save = True

vae_superres_pattern = ('../estimated/mnist/full-input/superres/0.1/196/vae/'
                        '0.0_1.0_0.1_momentum_0.01_0.9_False_500_10/{0}.png')
view([vae_superres_pattern], images_nums, hparams, alg_labels=False)

utils.save_plot(is_save, '../results/mnist_superres_orig_blurred_vae.pdf')

# Projection Distance

In [None]:
# Projection onto the VAE: results come from the "inpaint" run with 784
# (= hparams.n_input) measurements and 0.0 noise.
# NOTE(review): measurement_type is left empty rather than 'inpaint' —
# presumably so utils.image_matrix draws no extra measurement row; confirm.
hparams.measurement_type = ''
hparams.model_types = ['VAE']
is_save = True

projection_pattern = ('../estimated/mnist/full-input/inpaint/0.0/784/vae/'
                      '0.0_1.0_0.0_adam_0.1_0.9_False_1000_10/{0}.png')
view([projection_pattern], images_nums, hparams, alg_labels=False)

utils.save_plot(is_save, '../results/mnist_projection_orig_vae.pdf')

# End-to-end: fixed vs. learned measurements

In [None]:
hparams.measurement_type = 'fixed'
is_save = True

# One row per (measurement type, number of measurements), in the order
# Fixed10, Fixed20, Fixed30, Learned10, Learned20, Learned30.
base_pattern = '../estimated/mnist/full-input/{0}/0.1/{1}/learned/50-200/{2}.png'
row_specs = [(matrix_kind, m) for matrix_kind in ['fixed', 'learned'] for m in [10, 20, 30]]
hparams.model_types = ['{}{}'.format(matrix_kind.title(), m) for matrix_kind, m in row_specs]
# Pass '{0}' through so each pattern keeps a placeholder for the image number.
patterns = [base_pattern.format(matrix_kind, m, '{0}') for matrix_kind, m in row_specs]

view(patterns, images_nums, hparams, alg_labels=True)
save_path = '../results/mnist_e2e_orig_fixed_learned.pdf'
utils.save_plot(is_save, save_path)

# Variation with noise

In [None]:
# Lasso vs. VAE at 100 Gaussian measurements as the noise level grows.
# NOTE(review): measurement_type is left empty although the runs are Gaussian —
# presumably so utils.image_matrix draws no extra measurement row; confirm.
hparams.measurement_type = ''
hparams.model_types = ['Lasso', 'VAE']

for noise_std in [0.1, 1.0, 10.0, 100.0, 1000.0]:
    noise_dir = '../estimated/mnist/full-input/gaussian/' + str(noise_std) + '/100'
    # The VAE weight in the run directory scales as 10 * noise_std**2. Round it
    # before formatting: in floating point 10 * 0.1**2 == 0.10000000000000002,
    # whereas the same run (noise 0.1, 100 measurements) is spelled
    # '0.0_1.0_0.1_adam...' in the reconstructions cell.
    vae_weight = round(10 * noise_std**2, 10)
    pattern1 = noise_dir + '/lasso/0.1/{0}.png'
    pattern2 = noise_dir + '/vae/0.0_1.0_' + str(vae_weight) + '_adam_0.01_0.9_False_1000_10/{0}.png'
    patterns = [pattern1, pattern2]
    view(patterns, images_nums, hparams, alg_labels=True)