In [1]:
#%matplotlib inline 
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
from loaddata import load_mnist
import model
import pickle
from pathlib import Path

# Dataset splits, as returned by the project's loader.
data = load_mnist()

# Experiment hyper-parameters; train() pickles this dict next to the results.
cfg = dict(
    n_epochs=150,           # number of training epochs per run
    n_sgd=128,              # mini-batch size used in train()
    n_eta_sgd=1000,         # only referenced by the disabled eta-optimizer code
    n_sigma_sgd=1000,       # only referenced by the disabled sigma-optimizer code
    beta_epoch=0,           # epoch at which train() assigns the target beta
    train_eta=False,
    train_sigma=True,       # forwarded to model.Net as trainable_sigma
    encoder_arch=[(512, tf.nn.relu), (512, tf.nn.relu), (2, None)],
    decoder_arch=[(512, tf.nn.relu)],
)

# Print statistics every `report_every` epochs.
report_every = 10
# Force re-training of the base model even if a checkpoint already exists.
make_basemodel = False
# Output directory for pickled results.
savedir = 'outdata/v3.2'
# Checkpoint prefix of the shared base model, anchored at the current working directory.
basemodelpath = '%s/basemodel/v3' % Path().absolute()

  from ._conv import register_converters as _register_converters


In [2]:
def stats(epoch, data, n, do_print=False):
    """Evaluate the network and collect the current metrics into a dict.

    Runs in the module-level TensorFlow session ``sess``.

    Args:
        epoch: Zero-based epoch index. Stored unchanged in the result and
            printed as ``epoch+1``.
        data: Mapping with 'train_data', 'train_labels', 'test_data' and
            'test_labels' entries.
        n: Network object exposing the tensors read here (``x``, ``y``,
            ``beta``, ``log_sigma2``, ``log_eta2``, ``accuracy``,
            ``nlIB_loss``, ``VIB_loss``, ``Ixt``, ``Ixt_lb``, ``vIxt``,
            ``Iyt`` and ``encoder``).
        do_print: When true, print a train / test summary of every metric.

    Returns:
        dict with 'epoch', 'beta', 'log_sigma2', 'log_eta2', the keys
        'trn_<m>' and 'tst_<m>' for each metric ``<m>`` in acc, nlIBloss,
        VIBloss, Ixt, Ixt_lb, vIxt, Iyt, plus 'activations_trn' and
        'activations_tst' (outputs of the last encoder layer).
    """
    cur_beta, cur_log_sigma2, cur_log_eta2 = sess.run([n.beta, n.log_sigma2, n.log_eta2])

    # (result-key suffix, tensor) pairs. Building the result from this single
    # list keeps the train and test keys in sync; the original dict literal
    # wrote 'tst_vIxt' twice and therefore never stored 'trn_vIxt'.
    metrics = [('acc', n.accuracy), ('nlIBloss', n.nlIB_loss), ('VIBloss', n.VIB_loss),
               ('Ixt', n.Ixt), ('Ixt_lb', n.Ixt_lb), ('vIxt', n.vIxt), ('Iyt', n.Iyt)]
    measures = [tensor for _, tensor in metrics]

    # Metrics are evaluated on every 10th sample of each split.
    trn_vals = sess.run(measures, feed_dict={n.x: data['train_data'][::10], n.y: data['train_labels'][::10]})
    tst_vals = sess.run(measures, feed_dict={n.x: data['test_data'][::10], n.y: data['test_labels'][::10]})

    # Last-encoder-layer activations: every 10th training sample, but the
    # full test set (unlike the test metrics above).
    activations_trn = sess.run(n.encoder[-1], feed_dict={n.x: data['train_data'][::10]})
    activations_tst = sess.run(n.encoder[-1], feed_dict={n.x: data['test_data']})

    cdata = {'epoch': epoch, 'beta': cur_beta, 'log_sigma2': cur_log_sigma2, 'log_eta2': cur_log_eta2}
    for (name, _), trn_val, tst_val in zip(metrics, trn_vals, tst_vals):
        cdata['trn_' + name] = trn_val
        cdata['tst_' + name] = tst_val
    cdata['activations_trn'] = activations_trn
    cdata['activations_tst'] = activations_tst

    if do_print:
        print()
        print('epoch    : %d | beta: %0.4f | log(noisevar): %0.2f | log(kernelwidth): %0.2f'%
              (epoch+1, cur_beta, cur_log_sigma2, cur_log_eta2))
        # Display labels in the same order as `metrics`; '%-9s' pads each
        # label so the colons line up exactly as in the original output.
        labels = ['acc', 'nlIBloss', 'VIBloss', 'I(X;T)', 'I(X;T)lb', 'vI(X;T)', 'I(Y;T)']
        for label, trn_val, tst_val in zip(labels, trn_vals, tst_vals):
            print('%-9s: % 0.2f / % 0.2f' % (label, trn_val, tst_val))

    return cdata


def write_data(d, savedir, fname):
    """Pickle ``d`` to the file ``fname`` inside ``savedir``.

    Args:
        d: Any picklable object.
        savedir: Target directory; created, with parents, if missing.
        fname: File name inside ``savedir``. An existing file is overwritten.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(savedir, exist_ok=True)
    with open(os.path.join(savedir, fname), 'wb') as fp:
        pickle.dump(d, fp)


def train(cfg, data, n, trainstep, beta, report_every, fname):
    """Train the network for ``cfg['n_epochs']`` epochs, checkpointing statistics.

    Uses the module-level session ``sess`` and output directory ``savedir``.
    Before every epoch the current statistics are collected with ``stats``
    and the whole history is re-written to ``savedir/fname``; one more entry
    is recorded after the last epoch.

    Earlier experiments that the original left commented out (a per-epoch
    x1.1 ramp of beta capped at ``beta``, and separate sigma / eta optimizer
    steps) are not performed.

    Args:
        cfg: Configuration dict; reads 'n_sgd' (mini-batch size), 'n_epochs'
            and 'beta_epoch' (epoch at which ``beta`` is assigned).
        data: Mapping with 'train_data' / 'train_labels' arrays (plus the
            entries ``stats`` reads).
        n: Network object exposing ``x``, ``y`` and ``beta``.
        trainstep: Op run once per mini-batch.
        beta: Value assigned to ``n.beta`` at epoch ``cfg['beta_epoch']``.
        report_every: Print statistics every this many epochs.
        fname: File name of the pickled results inside ``savedir``.

    Returns:
        List of the per-epoch statistics dicts (``n_epochs + 1`` entries).
    """
    n_sgd = cfg['n_sgd']
    n_epochs = cfg['n_epochs']
    n_train = len(data['train_labels'])
    # Only full mini-batches are used; a trailing partial batch is dropped.
    n_mini_batches = n_train // n_sgd
    saved_data = []

    for epoch in range(n_epochs):
        if epoch == cfg['beta_epoch']:
            # NOTE(review): assign() presumably adds a new op to the graph on
            # every train() call -- confirm whether this matters for long sweeps.
            sess.run(n.beta.assign(beta))

        cdata = stats(epoch, data, n, epoch % report_every == 0)
        saved_data.append(cdata)
        write_data([cfg, saved_data], savedir, fname)

        # randomize order of training data
        permutation = np.random.permutation(n_train)
        train_data = data['train_data'][permutation]
        train_labels = data['train_labels'][permutation]

        for batch in range(n_mini_batches):
            # slice out the next mini-batch
            start = batch * n_sgd
            stop = start + n_sgd
            cparams = {n.x: train_data[start:stop], n.y: train_labels[start:stop]}
            sess.run(trainstep, feed_dict=cparams)

    # Final statistics after the last epoch. Using n_epochs rather than the
    # loop variable keeps this valid when n_epochs == 0 (the loop variable
    # would be unbound and raise NameError).
    cdata = stats(n_epochs, data, n, do_print=True)
    saved_data.append(cdata)
    write_data([cfg, saved_data], savedir, fname)

    return saved_data

In [3]:
# Clear the default graph, then open the session that stats() and train()
# use through the module-level name `sess`.
tf.reset_default_graph()
sess=tf.Session()

# Build the network from the architectures in cfg. log_sigma2 / log_eta2 /
# init_beta are passed as starting values -2 / -20 / 0.0.
n = model.Net(encoder_arch=cfg['encoder_arch'], 
              decoder_arch=cfg['decoder_arch'],
              trainable_sigma=cfg['train_sigma'], log_sigma2=-2, log_eta2=-20, init_beta=0.0)
# NOTE(review): presumably has to come after model.Net so the network's
# variables exist when the Saver is constructed -- confirm against the
# tf.train.Saver documentation before reordering.
saver = tf.train.Saver()


In [4]:
# Train and save the shared base model when explicitly requested, or when
# no checkpoint metadata file exists yet at `basemodelpath`.
base_checkpoint_missing = not os.path.exists(basemodelpath + '.meta')
if make_basemodel or base_checkpoint_missing:
    print("Making base model")

    # Same configuration as the main runs, but limited to 50 epochs.
    cfg2 = dict(cfg, n_epochs=50)

    sess.run(tf.global_variables_initializer())
    train(cfg2, data, n,
          trainstep=n.no_log_sigma2_trainstep,
          beta=0.0,
          report_every=report_every,
          fname='results-base')

    save_path = saver.save(sess, basemodelpath)
    print("Model saved in path: %s" % save_path)

In [5]:
#sess  = tf.Session()
#n = model.Net(encoder_arch=[(512,'relu'),(512,'relu'),(2,'relu')], decoder_arch=[(512,'relu'),],
#              trainable_sigma=cfg['train_sigma'], log_sigma2=-2, log_eta2=-20, init_beta=0.0)
#loader = tf.train.import_meta_graph(basemodelpath+'.meta')
#saver.restore(sess, basemodelpath)


In [None]:
# Sweep: 5 repetitions x every non-zero beta on a 21-point grid over [0, 1]
# x both objectives. Each run restarts from the saved base model.
for runndx in range(5):
    nonzero_betas = (b for b in np.linspace(0, 1, 21, endpoint=True)
                     if not np.isclose(b, 0))
    for beta in nonzero_betas:
        for mode in ('VIB', 'nlIB'):
            saver.restore(sess, basemodelpath)

            # The train op is looked up by name: n.VIB_trainstep / n.nlIB_trainstep.
            if mode not in ('nlIB', 'VIB'):
                raise Exception('Unknown mode')
            trainstep = getattr(n, mode + '_trainstep')

            print("Doing %s, %0.4f" % (mode, beta))
            fname = 'results-%s-%0.5f-run%d' % (mode, beta, runndx)
            train(cfg, data, n, trainstep, beta, report_every=report_every, fname=fname)

            # Two blank lines between runs.
            print()
            print()

Doing VIB, 0.0500

epoch    : 1 | beta: 0.0500 | log(noisevar): -2.00 | log(kernelwidth): -20.00
acc      :  1.00 /  0.97
nlIBloss :  1.14 /  0.20
VIBloss  :  441943.12 /  470436.53
I(X;T)   :  8.28 /  6.82
I(X;T)lb :  7.62 /  6.62
vI(X;T)  :  2973.03 /  3067.37
I(Y;T)   :  2.29 /  2.13

epoch    : 11 | beta: 0.0500 | log(noisevar): -1.57 | log(kernelwidth): -20.00
acc      :  0.37 /  0.34
nlIBloss : -0.66 / -0.61
VIBloss  : -0.32 / -0.29
I(X;T)   :  1.42 /  1.38
I(X;T)lb :  0.82 /  0.79
vI(X;T)  :  2.95 /  2.89
I(Y;T)   :  0.76 /  0.71

epoch    : 21 | beta: 0.0500 | log(noisevar): -0.69 | log(kernelwidth): -20.00
acc      :  0.37 /  0.37
nlIBloss : -0.66 / -0.64
VIBloss  : -0.57 / -0.56
I(X;T)   :  1.36 /  1.35
I(X;T)lb :  0.72 /  0.70
vI(X;T)  :  1.91 /  1.85
I(Y;T)   :  0.75 /  0.73

epoch    : 31 | beta: 0.0500 | log(noisevar): -0.89 | log(kernelwidth): -20.00
acc      :  0.61 /  0.63
nlIBloss : -1.01 / -1.03
VIBloss  : -0.88 / -0.91
I(X;T)   :  2.21 /  2.18
I(X;T)lb :  1.26 /  1.

In [None]:
import seaborn as sns

# Scatter plot of the last encoder layer's output for the whole test set
# (that layer has 2 units in cfg['encoder_arch']).
plt.figure(figsize=(10, 10))
x, y = data['test_data'], data['test_labels']
mx = sess.run(n.encoder[-1], feed_dict={n.x: x})

# Colour index per sample: argmax over axis 1 of the labels
# (presumably one-hot encoded -- confirm against load_mnist).
class_ids = np.argmax(y, axis=1)
plt.scatter(mx[:, 0], mx[:, 1], s=4, c=class_ids, alpha=0.05)
plt.axis('off');

In [None]:
import entropy

# Build the pairwise-distance tensor on the last encoder layer, then evaluate
# entropy.GMM_entropy(d, 1., 2, 'upper') on the test inputs `x`.
# NOTE(review): the meaning of the arguments 1., 2 and 'upper' is defined in
# the project's entropy module -- check there before changing them.
d = entropy.pairwise_distance(n.encoder[-1])
gmm_entropy_op = entropy.GMM_entropy(d, 1., 2, 'upper')
sess.run(gmm_entropy_op, feed_dict={n.x: x})