In [1]:
#%matplotlib inline 
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
from loaddata import load_mnist
import model
import pickle

  from ._conv import register_converters as _register_converters


In [2]:
data = load_mnist()
tf.reset_default_graph()
n_epochs     = 200
n_sgd        = 128
n_eta_sgd    = 1000
beta         = 0.05
beta_epoch   = 20
report_every = 10

n = model.Net(encoder_arch=[(512,'relu'),(512,'relu'),], decoder_arch=[(512,'relu'),],
              trainable_sigma=False, log_sigma2=-4, init_beta=0.0)

# train model
sess=tf.Session()
#with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

n_mini_batches = int(len(data['train_labels']) / n_sgd)

saved_data = []
fname = 'sqIBv1'

for epoch in range(n_epochs):
    # randomize order of training data
    permutation = np.random.permutation(len(data['train_labels']))
    train_data = data['train_data'][permutation]
    train_labels = data['train_labels'][permutation]

    x_batch = train_data[:n_eta_sgd]
    y_batch = train_labels[:n_eta_sgd]

    #if epoch == 0:
    if n.trainable_sigma:
        n.sigma_optimizer.minimize(sess, feed_dict={n.x: x_batch, n.y: y_batch})


    for batch in range(n_mini_batches):
        # sample mini-batch
        x_batch = train_data[batch * n_sgd:(1 + batch) * n_sgd]
        y_batch = train_labels[batch * n_sgd:(1 + batch) * n_sgd]

        cparams = {n.x: x_batch, n.y: y_batch}
        sess.run(n.trainstep, feed_dict=cparams)

    if epoch == beta_epoch:
        assign_op = n.beta.assign(beta)
        sess.run(assign_op)
    n.eta_optimizer.minimize(sess, feed_dict={n.x: x_batch})

    cur_beta, cur_log_sigma2, cur_log_eta2 = sess.run([n.beta, n.log_sigma2, n.log_eta2])

    tst_acc, tst_loss, tst_Ixt, tst_Iyt = \
        sess.run([n.accuracy, n.loss, n.Ixt, n.Iyt], feed_dict={n.x: data['test_data'][::10], n.y: data['test_labels'][::10]})
    trn_acc, trn_loss, trn_Ixt, trn_Iyt = \
        sess.run([n.accuracy, n.loss, n.Ixt, n.Iyt], feed_dict={n.x: data['train_data'][::10], n.y: data['train_labels'][::10]})

    cdata = {'epoch': epoch, 'beta': cur_beta, 'log_sigma2':cur_log_sigma2, 'log_eta2':cur_log_eta2,
             'trn_acc': trn_acc, 'tst_acc': tst_acc, 'trn_loss': trn_loss, 'tst_loss': tst_loss,
             'trn_Ixt': trn_Ixt, 'tst_Ixt': tst_Ixt, 'trn_Iyt' : trn_Iyt , 'tst_Iyt' : tst_Iyt}
    saved_data.append(cdata)
    
    with open('outdata/'+fname, 'wb') as fp:
        pickle.dump(saved_data, fp)

    if epoch % report_every == 0:
        print()
        print('epoch  : %d/%d | beta: %0.4f | noisevar: %0.4f | kernelwidth: %0.4f'%
              (epoch+1, n_epochs, cur_beta, np.exp(cur_log_sigma2), np.exp(cur_log_eta2)))
        print('acc    : % 0.4f / % 0.4f' % (trn_acc, tst_acc))
        print('loss   : % 0.4f / % 0.4f' % (trn_loss, tst_loss))
        print('I(X;T) : % 0.4f / % 0.4f' % (trn_Ixt, tst_Ixt))
        print('I(Y;T) : % 0.4f / % 0.4f' % (trn_Iyt, tst_Iyt))
            


epoch  : 1/200 | beta: 0.0000 | noisevar: 0.0183 | kernelwidth: 11.3015
acc    :  0.9115 /  0.9050
loss   : -1.9467 / -1.9432
I(X;T) :  9.4369 /  9.4321
I(Y;T) :  1.9467 /  1.9432

epoch  : 11/200 | beta: 0.0000 | noisevar: 0.0183 | kernelwidth: 10.6389
acc    :  0.9763 /  0.9650
loss   : -2.2088 / -2.1533
I(X;T) :  9.4864 /  9.4773
I(Y;T) :  2.2088 /  2.1533

epoch  : 21/200 | beta: 0.0500 | noisevar: 0.0183 | kernelwidth: 32.7269
acc    :  0.9885 /  0.9670
loss   :  2.9347 /  3.0239
I(X;T) :  10.1849 /  10.1875
I(Y;T) :  2.2519 /  2.1654

epoch  : 31/200 | beta: 0.0500 | noisevar: 0.0183 | kernelwidth: 0.0032
acc    :  0.9812 /  0.9690
loss   : -1.8297 / -1.7586
I(X;T) :  2.7762 /  2.7858
I(Y;T) :  2.2151 /  2.1466

epoch  : 41/200 | beta: 0.0500 | noisevar: 0.0183 | kernelwidth: 0.0058
acc    :  0.9895 /  0.9680
loss   : -1.8612 / -1.7288
I(X;T) :  2.7696 /  2.8106
I(Y;T) :  2.2447 /  2.1238

epoch  : 51/200 | beta: 0.0500 | noisevar: 0.0183 | kernelwidth: 0.0030
acc    :  0.9903 /

KeyboardInterrupt: 

In [None]:
# with tf.Session() as sess:
import seaborn as sns
plt.figure(figsize=(10,10))
x, y = data['test_data'], data['test_labels']
mx = sess.run(n.encoder[-1], feed_dict={n.x: x})
plt.scatter(mx[:,0],mx[:,1], s=4, c=np.argmax(y,axis=1), alpha=0.1)
plt.axis('off');

In [None]:
mx