In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

### load ellipses

In [2]:
def load_ellipses(path='data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'):
    """Load the ellipse subset of a dSprites ``.npz`` archive.

    Parameters
    ----------
    path : str, optional
        Location of the archive. Defaults to the path this notebook has
        always read from, so existing zero-argument calls are unaffected.

    Returns
    -------
    ellipse_imgs : np.ndarray
        The ellipse images, each one flattened into a single row
        (shape ``(n_ellipses, pixels_per_image)``).
    ellipse_lat_vals : np.ndarray
        The rows of ``latents_values`` belonging to those same images.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open;
    # the context manager closes it once the arrays we need are in memory.
    with np.load(path, encoding='latin1') as dataset:
        # archive keys: ['metadata', 'imgs', 'latents_classes', 'latents_values']
        lat_vals = dataset['latents_values']
        # latent value columns, in order:
        #   Color: white
        #   Shape: square, ellipse, heart
        #   Scale: 6 values linearly spaced in [0.5, 1]
        #   Orientation: 40 values in [0, 2 pi]
        #   Position X: 32 values in [0, 1]
        #   Position Y: 32 values in [0, 1]
        ellipse_idxs = np.where(lat_vals[:, 1] == 2)[0]  # shape value 2 == ellipse
        ellipse_imgs = dataset['imgs'][ellipse_idxs]

    ellipse_lat_vals = lat_vals[ellipse_idxs]

    # Spell out the row width instead of using -1: reshape(0, -1) raises when
    # no ellipse is present, whereas an explicit width yields an empty matrix.
    pixels_per_image = int(np.prod(ellipse_imgs.shape[1:]))
    ellipse_imgs = ellipse_imgs.reshape(ellipse_imgs.shape[0], pixels_per_image)

    return ellipse_imgs, ellipse_lat_vals

In [3]:
# Flattened ellipse images (one row per image) and their matching latent-value rows.
ellipses_imgs, ellipses_lat_vals = load_ellipses()

In [4]:
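# Preview one random ellipse. load_ellipses flattens every image to a single row,
# so it must be reshaped before plotting (64x64 assumed from the dataset filename).
#plt.imshow(ellipses_imgs[np.random.randint(ellipses_imgs.shape[0])].reshape(64, 64))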

### create model

In [5]:
from absl import app as absl_app
from absl import flags as absl_flags

def _define_flag(define, name, default, help_text):
    """Register a flag via ``define``, or refresh its default if it already exists.

    absl raises DuplicateFlagError when a DEFINE_* call repeats an existing
    flag name, which would make this cell fail every time it is re-executed
    in the same kernel. Guarding the definition keeps the cell re-runnable,
    and updating the default on re-run means an edited default still applies.
    """
    if name in absl_flags.FLAGS:
        absl_flags.FLAGS.set_default(name, default)
    else:
        define(name, default, help_text)


# Training schedule
_define_flag(absl_flags.DEFINE_integer, "epoch_size", 2000, "epoch size")
_define_flag(absl_flags.DEFINE_integer, "batch_size", 64, "batch size")
# Latent-loss hyperparameters (passed to the VAE constructor)
_define_flag(absl_flags.DEFINE_float, "gamma", 100.0, "gamma param for latent loss")
_define_flag(absl_flags.DEFINE_float, "capacity_limit", 20.0, "encoding capacity limit param for latent loss")
_define_flag(absl_flags.DEFINE_integer, "capacity_change_duration", 100000, "encoding capacity change duration")
_define_flag(absl_flags.DEFINE_float, "learning_rate", 5e-4, "learning rate")
# Bookkeeping
_define_flag(absl_flags.DEFINE_string, "checkpoint_dir", "checkpoints", "checkpoint directory")
_define_flag(absl_flags.DEFINE_string, "log_file", "./log", "log file directory")
_define_flag(absl_flags.DEFINE_boolean, "training", True, "training or not")

# Short alias used by the rest of the notebook.
flags = absl_flags.FLAGS

In [6]:
# Parse an argv that holds only a placeholder program name, so the flag values
# can be read below. Presumably this keeps the kernel's own command-line
# arguments away from absl's parser — TODO confirm.
_ = flags(["dummy"]) # jupyter notebook hack

In [7]:
from model import VAE

# One TF1 session, kept open because later cells train through it.
sess = tf.Session()

# Hyperparameters are taken from the parsed flags and handed to the VAE by keyword.
vae_params = {
    "gamma": flags.gamma,
    "capacity_limit": flags.capacity_limit,
    "capacity_change_duration": flags.capacity_change_duration,
    "learning_rate": flags.learning_rate,
}
model = VAE(**vae_params)

# Variables created while building the model must be initialised before use.
init_op = tf.global_variables_initializer()
sess.run(init_op)






Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where





In [16]:
# --- Training loop -----------------------------------------------------------
# Splits the ellipse images into fixed-size mini-batches and feeds them to
# model.partial_fit, reshuffling the image order at the start of each epoch.

n_images = ellipses_imgs.shape[0]
n_batches = n_images // flags.batch_size
last_batch_size = n_images - n_batches * flags.batch_size
# Index array (not a Python list): shuffled in place each epoch and sliced
# cheaply into per-batch index arrays.
shuffled_img_indices = np.arange(n_images)

print('%d images, %d batches %d images each, and %d images left'
      % (n_images, n_batches, flags.batch_size, last_batch_size))
# The loop below only visits full batches, so insist that none are left over.
assert last_batch_size == 0, (
    '%d images would be skipped each epoch; pick a batch size that divides %d'
    % (last_batch_size, n_images))

# Global step, counted across ALL epochs. It used to be reset to 0 at the top
# of every epoch, so partial_fit never saw a step above n_batches - 1.
# NOTE(review): partial_fit presumably uses `step` for the capacity schedule
# governed by capacity_change_duration — confirm against model.py.
step = 0

for epoch in range(flags.epoch_size):
    print('epoch %d of %d (%d batches %d images each)' % (epoch, flags.epoch_size, n_batches, flags.batch_size))

    # reshuffle image indices each epoch
    np.random.shuffle(shuffled_img_indices)

    for i in range(n_batches):
        # Generate image batch
        batch_indices = shuffled_img_indices[flags.batch_size*i : flags.batch_size*(i+1)]
        batch_imgs = ellipses_imgs[batch_indices]

        # Fit training using batch data
        reconstr_loss, latent_loss, summary_str = model.partial_fit(sess, batch_imgs, step)
        #summary_writer.add_summary(summary_str, step)

        # Progress dots follow the per-epoch batch index, so each epoch prints
        # the same dot pattern as before even though `step` no longer resets.
        if i % 100 == 0:
            print('.', end='')
        step += 1

    print(' done')


245760 images, 3840 batches 64 images each, and 0 images in the last batch
