In [1]:
from tensorflow.examples.tutorials.mnist import input_data
from rebar_tf import *
import tensorflow as tf
import numpy as np
import os
from utils import *

BATCH_SIZE = 100
params = {'rounded_mnist':True,
          'batch_size':BATCH_SIZE,
          'binarized_mnist': False}


if params['binarized_mnist']:
    train_loader,valid_loader,test_loader = load_binarized_mnist(params)
else:
    train_loader,test_loader = load_pytorch_mnist(params=params,validation_split=False)

def encoder(x):
    if len(gs(x)) > 2:
        p = np.prod(gs(x)[1:])
        x = tf.reshape(x, [-1, p])
    #h1 = tf.layers.dense(2. * x - 1., 200, tf.nn.relu, name="encoder_1")
    #h2 = tf.layers.dense(h1, 200, tf.nn.relu, name="encoder_2")
    log_alpha = tf.layers.dense(x, 20, name="encoder_out")
    return log_alpha

def decoder(b):
    #h1 = tf.layers.dense(2. * b - 1., 200, tf.nn.relu, name="decoder_1")
    #h2 = tf.layers.dense(h1, 200, tf.nn.relu, name="decoder_2")
    log_alpha = tf.layers.dense(b, 784, name="decoder_out")
    return log_alpha

def Q_func(z):
    h1 = tf.layers.dense(2. * z - 1., 50, tf.nn.relu, name="q_1", use_bias=True)
    out = tf.layers.dense(h1, 1, name="q_out", use_bias=True)
    scale = tf.get_variable(
        "q_scale", shape=[1], dtype=tf.float32,
        initializer=tf.constant_initializer(0), trainable=True
    )
    return scale[0] * out


if __name__ == "__main__":
    TRAIN_DIR = "./rebar_new_u_and_v"
    reinforce = False
    relaxed = False
    if os.path.exists(TRAIN_DIR):
        print("Deleting existing train dir")
        import shutil

        shutil.rmtree(TRAIN_DIR)
    os.makedirs(TRAIN_DIR)
    sess = tf.Session()
    batch_size = BATCH_SIZE
    lr = .001
    #dataset = input_data.read_data_sets("MNIST_data/", one_hot=True)

    def to_vec(t):
        return tf.reshape(t, [-1])
    def from_vec(t):
        return tf.reshape(t, [batch_size, -1])

    x = tf.placeholder(tf.float32, [batch_size, 784])
    x_im = tf.reshape(x, [batch_size, 28, 28, 1])
    tf.summary.image("x_true", x_im)
    x_binary = tf.to_float(x > .5)
    log_alpha = encoder(x_binary)
    log_alpha_v = tf.reshape(log_alpha, [-1])
    evals = 0
    def loss(b):
        log_q_b_given_x = bernoulli_loglikelihood(b, log_alpha)
        log_q_b_given_x = tf.reduce_mean(tf.reduce_sum(log_q_b_given_x, axis=1))

        log_p_b = bernoulli_loglikelihood(b, tf.zeros_like(log_alpha))
        log_p_b = tf.reduce_mean(tf.reduce_sum(log_p_b, axis=1))

        with tf.variable_scope("decoder", reuse=evals>0):
            log_alpha_x_batch = decoder(b)
        log_p_x_given_b = bernoulli_loglikelihood(x_binary, log_alpha_x_batch)
        log_p_x_given_b = tf.reduce_mean(tf.reduce_sum(log_p_x_given_b, axis=1))
        # HACKY BS
        global evals
        if evals == 0:
            # if first eval make image summary
            a = tf.exp(log_alpha_x_batch)
            log_theta_x = a / (1 + a)
            log_theta = tf.reshape(log_theta_x, [batch_size, 28, 28, 1])
            tf.summary.image("x_pred", log_theta)
        evals += 1
        return -tf.expand_dims(log_p_x_given_b + log_p_b - log_q_b_given_x, 0)
    if relaxed:
        rebar_optimizer = RelaxedREBAROptimizer(sess, loss, Q_func, log_alpha=log_alpha, learning_rate=lr)
    else:
        rebar_optimizer = REBAROptimizer(sess, loss, log_alpha=log_alpha, learning_rate=lr)
    gen_loss = rebar_optimizer.f_b
    tf.summary.scalar("loss", gen_loss[0])
    gen_opt = tf.train.AdamOptimizer(lr)
    gen_vars = [v for v in tf.trainable_variables() if "decoder" in v.name]
    gen_gradvars = gen_opt.compute_gradients(gen_loss, var_list=gen_vars)
    gen_train_op = gen_opt.apply_gradients(gen_gradvars)

    alpha_grads = rebar_optimizer.reinforce if reinforce else rebar_optimizer.rebar
    inf_vars = [v for v in tf.trainable_variables() if "encode" in v.name]
    inf_grads = tf.gradients(log_alpha, inf_vars, grad_ys=alpha_grads)
    inf_gradvars = zip(inf_grads, inf_vars)
    inf_opt = tf.train.AdamOptimizer(lr)
    inf_train_op = inf_opt.apply_gradients(inf_gradvars)
    if relaxed:
        gradvars = inf_gradvars + gen_gradvars + rebar_optimizer.variance_gradvars + rebar_optimizer.Q_gradvars
    else:
        gradvars = inf_gradvars + gen_gradvars + rebar_optimizer.variance_gradvars
    for g, v in gradvars:
        tf.summary.histogram(v.name, v)
        tf.summary.histogram(v.name+"_grad", g)

    if reinforce:
        with tf.control_dependencies([gen_train_op, inf_train_op]):
            train_op = tf.no_op()
    else:
        with tf.control_dependencies([gen_train_op, inf_train_op, rebar_optimizer.variance_reduction_op]):
            train_op = tf.no_op()

    summ_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(TRAIN_DIR)
    sess.run(tf.global_variables_initializer())
    """
    for i in range(250000):
        batch_xs, _ = dataset.train.next_batch(100)
        if i % 100 == 0:
            loss, _, sum_str = sess.run([gen_loss, train_op, summ_op], feed_dict={x: batch_xs})
            summary_writer.add_summary(sum_str, i)
            print(i, loss[0])
        else:
            loss, _ = sess.run([gen_loss, train_op], feed_dict={x: batch_xs})
     """       
    losses=[]
    NUM_EPOCHS = 600
    for epoch in range(1,NUM_EPOCHS+1):
        train_loss = 0
        test_loss = 0
        for np_x,np_y in train_loader:
            #i += 1

            _,np_loss=sess.run([train_op,gen_loss],{
              x:np_x.numpy()
            })
            train_loss += np_loss

        for np_xtest,_ in test_loader:
            np_loss_test=sess.run(gen_loss,{
                  x:np_xtest.numpy()
                })
            test_loss += np_loss_test

        train_loss /= len(train_loader)
        test_loss /= len(test_loader)
        losses.append([train_loss,test_loss])
        print('epoch %d, train ELBO: %0.3f, test ELBO: %0.3f ' % (epoch,train_loss, test_loss))
    import torch
    torch.save(losses,'tf_rebar_20_binary_our_data.tar')

  global evals


Deleting existing train dir
INFO:tensorflow:Summary name encoder_out/kernel:0 is illegal; using encoder_out/kernel_0 instead.
INFO:tensorflow:Summary name encoder_out/kernel:0_grad is illegal; using encoder_out/kernel_0_grad instead.
INFO:tensorflow:Summary name encoder_out/bias:0 is illegal; using encoder_out/bias_0 instead.
INFO:tensorflow:Summary name encoder_out/bias:0_grad is illegal; using encoder_out/bias_0_grad instead.
INFO:tensorflow:Summary name decoder/decoder_out/kernel:0 is illegal; using decoder/decoder_out/kernel_0 instead.
INFO:tensorflow:Summary name decoder/decoder_out/kernel:0_grad is illegal; using decoder/decoder_out/kernel_0_grad instead.
INFO:tensorflow:Summary name decoder/decoder_out/bias:0 is illegal; using decoder/decoder_out/bias_0 instead.
INFO:tensorflow:Summary name decoder/decoder_out/bias:0_grad is illegal; using decoder/decoder_out/bias_0_grad instead.
INFO:tensorflow:Summary name eta:0 is illegal; using eta_0 instead.
INFO:tensorflow:Summary name eta

epoch 137, train ELBO: 131.387, test ELBO: 130.475 
epoch 138, train ELBO: 131.286, test ELBO: 130.662 
epoch 139, train ELBO: 131.237, test ELBO: 130.528 
epoch 140, train ELBO: 131.285, test ELBO: 130.508 
epoch 141, train ELBO: 131.220, test ELBO: 130.463 
epoch 142, train ELBO: 131.129, test ELBO: 130.493 
epoch 143, train ELBO: 131.143, test ELBO: 130.554 
epoch 144, train ELBO: 131.137, test ELBO: 130.435 
epoch 145, train ELBO: 131.104, test ELBO: 130.433 
epoch 146, train ELBO: 131.165, test ELBO: 130.335 
epoch 147, train ELBO: 131.041, test ELBO: 130.428 
epoch 148, train ELBO: 131.038, test ELBO: 130.292 
epoch 149, train ELBO: 131.009, test ELBO: 130.308 
epoch 150, train ELBO: 131.034, test ELBO: 130.317 
epoch 151, train ELBO: 131.075, test ELBO: 130.380 
epoch 152, train ELBO: 130.974, test ELBO: 130.188 
epoch 153, train ELBO: 130.926, test ELBO: 130.214 
epoch 154, train ELBO: 130.934, test ELBO: 130.182 
epoch 155, train ELBO: 130.945, test ELBO: 130.150 
epoch 156, t

epoch 295, train ELBO: 129.126, test ELBO: 128.652 
epoch 296, train ELBO: 129.143, test ELBO: 128.577 
epoch 297, train ELBO: 129.099, test ELBO: 128.447 
epoch 298, train ELBO: 129.070, test ELBO: 128.556 
epoch 299, train ELBO: 129.084, test ELBO: 128.445 
epoch 300, train ELBO: 129.067, test ELBO: 128.478 
epoch 301, train ELBO: 129.058, test ELBO: 128.584 
epoch 302, train ELBO: 129.076, test ELBO: 128.448 
epoch 303, train ELBO: 129.079, test ELBO: 128.483 
epoch 304, train ELBO: 129.049, test ELBO: 128.409 
epoch 305, train ELBO: 129.027, test ELBO: 128.458 
epoch 306, train ELBO: 129.051, test ELBO: 128.489 
epoch 307, train ELBO: 129.083, test ELBO: 128.468 
epoch 308, train ELBO: 129.092, test ELBO: 128.433 
epoch 309, train ELBO: 129.046, test ELBO: 128.397 
epoch 310, train ELBO: 129.017, test ELBO: 128.388 
epoch 311, train ELBO: 128.975, test ELBO: 128.369 
epoch 312, train ELBO: 129.005, test ELBO: 128.353 
epoch 313, train ELBO: 128.989, test ELBO: 128.393 
epoch 314, t

epoch 453, train ELBO: 128.404, test ELBO: 127.924 
epoch 454, train ELBO: 128.459, test ELBO: 127.839 
epoch 455, train ELBO: 128.302, test ELBO: 127.727 
epoch 456, train ELBO: 128.319, test ELBO: 127.597 
epoch 457, train ELBO: 128.349, test ELBO: 127.736 
epoch 458, train ELBO: 128.354, test ELBO: 127.788 
epoch 459, train ELBO: 128.327, test ELBO: 127.808 
epoch 460, train ELBO: 128.324, test ELBO: 127.731 
epoch 461, train ELBO: 128.325, test ELBO: 127.743 
epoch 462, train ELBO: 128.301, test ELBO: 127.766 
epoch 463, train ELBO: 128.295, test ELBO: 127.747 
epoch 464, train ELBO: 128.290, test ELBO: 127.646 
epoch 465, train ELBO: 128.283, test ELBO: 127.764 
epoch 466, train ELBO: 128.263, test ELBO: 127.706 
epoch 467, train ELBO: 128.325, test ELBO: 127.782 
epoch 468, train ELBO: 128.273, test ELBO: 127.749 
epoch 469, train ELBO: 128.271, test ELBO: 127.664 
epoch 470, train ELBO: 128.304, test ELBO: 127.661 
epoch 471, train ELBO: 128.303, test ELBO: 127.697 
epoch 472, t

In [2]:
torch.save(losses,'tf_relaxed_20_binary_our_data.tar')

In [15]:
sess = tf.Session()
x = tf.placeholder(tf.float32, [None, 784])
x_binary = tf.to_float(x > .5)
xs,_=dataset.train.next_batch(1000000)
sess.run(x_binary,{x:xs}).shape

(55000, 784)

In [3]:
a = torch.load('tf_relaxed_20_binary_our_data.tar')
for i,j in enumerate(a):
    print i,j
    if i ==250:
        break

0 [array([258.2353], dtype=float32), array([217.66817], dtype=float32)]
1 [array([210.13254], dtype=float32), array([201.85635], dtype=float32)]
2 [array([194.4551], dtype=float32), array([185.99367], dtype=float32)]
3 [array([180.98694], dtype=float32), array([174.87529], dtype=float32)]
4 [array([172.06741], dtype=float32), array([167.37355], dtype=float32)]
5 [array([165.40747], dtype=float32), array([161.9403], dtype=float32)]
6 [array([161.07353], dtype=float32), array([158.60512], dtype=float32)]
7 [array([157.73726], dtype=float32), array([155.48267], dtype=float32)]
8 [array([155.30998], dtype=float32), array([153.492], dtype=float32)]
9 [array([153.18279], dtype=float32), array([151.54903], dtype=float32)]
10 [array([151.35089], dtype=float32), array([149.86533], dtype=float32)]
11 [array([150.12944], dtype=float32), array([148.66196], dtype=float32)]
12 [array([148.93877], dtype=float32), array([147.54285], dtype=float32)]
13 [array([147.83272], dtype=float32), array([146.454

236 [array([129.24696], dtype=float32), array([128.57494], dtype=float32)]
237 [array([129.22856], dtype=float32), array([128.62515], dtype=float32)]
238 [array([129.19905], dtype=float32), array([128.58641], dtype=float32)]
239 [array([129.19212], dtype=float32), array([128.45853], dtype=float32)]
240 [array([129.20209], dtype=float32), array([128.63304], dtype=float32)]
241 [array([129.17265], dtype=float32), array([128.51517], dtype=float32)]
242 [array([129.17487], dtype=float32), array([128.53528], dtype=float32)]
243 [array([129.17577], dtype=float32), array([128.57623], dtype=float32)]
244 [array([129.19937], dtype=float32), array([128.48648], dtype=float32)]
245 [array([129.17049], dtype=float32), array([128.51216], dtype=float32)]
246 [array([129.11565], dtype=float32), array([128.51018], dtype=float32)]
247 [array([129.20868], dtype=float32), array([128.47624], dtype=float32)]
248 [array([129.13792], dtype=float32), array([128.45961], dtype=float32)]
249 [array([129.14397], d