In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import scipy
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import time, os, pickle, pathlib

from loaddata import load_mnist, one_hot
import model
import trainutils

# ---- Run-level settings ----
# how often to print stats during training
report_every = 10
# how many times to repeat the whole scan across beta's
n_runs = 1
# root directory for saved results, anchored at the current working directory
savedirbase = str(pathlib.Path.cwd()) + '/saveddata/'

# Starting model, with no info-theoretic regularization
recreate_basemodel = True

# ---- Default hyperparameters shared by all run types ----
base_cfg = dict(
    # SGD batch size
    n_batch=128,
    # whether to train noise variance with gradient descent, or other
    #  scipy optimizer loop that runs once an epoch
    gradient_train_noisevar=True,
    # if gradient_train_noisevar=False, then size of batch for training noise variance
    n_noisevar_batch=1000,
    # whether to set noisevar to optimal value before training
    initial_fitvar=False,
)


# Select which experiment to run by leaving exactly one assignment uncommented.
# runtype = 'MNIST'
runtype = 'NoisyClassifier'
#runtype = 'Regression'
# Guard against typos: only these three run types are configured.
assert runtype in ('MNIST', 'NoisyClassifier', 'Regression')

if runtype == 'MNIST':
    # MNIST run: data comes from the project's loaddata.load_mnist helper.
    data = load_mnist()
    savedir = runtype + '/v1'
    # Each *_arch entry is a (layer width, activation) pair; None means no
    # activation is passed for that layer.
    cfg = dict(
        input_dims=784,
        # log(10): presumably the entropy (in nats) of 10 equiprobable classes -- TODO confirm
        entropyY=np.log(10),
        n_epochs=150,
        squaredIB=True,
        encoder_arch=[(512, tf.nn.relu), (512, tf.nn.relu), (2, None)],
        decoder_arch=[(512, tf.nn.relu), (10, None)],
        err_func='softmax_ce_v2',
    )

if runtype == 'NoisyClassifier':
    # `import scipy` alone does not guarantee that the `scipy.io` submodule is
    # loaded; import it explicitly so loadmat does not depend on another
    # package having imported it as a side effect.
    import scipy.io

    savedir = runtype + '/v1'
    # Data from artificial dataset used in Shwartz-Ziv and Tishby.
    # g1.mat is used for training and g2.mat for testing; each file provides
    # an 'F' array (inputs) and a 'y' array (labels).
    d1 = scipy.io.loadmat('data/g1.mat')
    d2 = scipy.io.loadmat('data/g2.mat')
    data = { 'trn_data' : d1['F'].astype('float32'), 'trn_labels': one_hot(d1['y'].flat),
             'tst_data' : d2['F'].astype('float32'), 'tst_labels' : one_hot(d2['y'].flat)}
    # Each *_arch entry is a (layer width, activation) pair; None means no
    # activation is passed for that layer.
    cfg = {
        'input_dims'    : 12,
        # log(2): presumably the entropy (in nats) of two equiprobable classes -- TODO confirm
        'entropyY'      : np.log(2),
        'n_epochs'      : 150,
        'squaredIB'     : False,
        'encoder_arch'  : [(20,tf.nn.relu),(20,tf.nn.relu),(2,None)], 
        'decoder_arch'  : [(20,tf.nn.relu),(2,None)],
        'err_func'      : 'softmax_ce_v2',
    }
    
if runtype == 'Regression':
    savedir = runtype + '/v5sq-eta'
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; only point this at a trusted, locally generated file.
    with open('data/regression.pkl', 'rb') as f:
        data = pickle.load(f)

    # Entropy of a Gaussian with the empirical label covariance:
    #   H(Y) = 0.5 * log det(2*pi*e * Cov(Y))
    # np.atleast_2d keeps the single-output case (where np.cov returns a 0-d
    # array) usable, and slogdet computes the log-determinant directly so that
    # a very large or very small determinant cannot overflow/underflow before
    # the logarithm is taken.
    labelcov = np.atleast_2d(np.cov(data['trn_labels'].T))
    logdet_sign, logdet = np.linalg.slogdet(2*np.pi*np.exp(1)*labelcov)
    if logdet_sign <= 0:
        # log(det) is undefined here; fail now rather than carrying a
        # NaN / -inf entropy into the model.
        raise ValueError('Label covariance is not positive definite; '
                         'cannot compute Gaussian entropy of the labels')
    entropyY = 0.5 * logdet
    # Each *_arch entry is a (layer width, activation) pair; None means no
    # activation is passed for that layer.
    cfg = {
        'input_dims'       : data['trn_data'].shape[1],
        'entropyY'         : entropyY,
        'n_epochs'         : 200,
        'encoder_arch'     : [(100,tf.nn.relu),(100,tf.nn.relu),(2,None)], 
        'decoder_arch'     : [(100,tf.nn.relu),(10,None)],
        'err_func'         : 'mse',
        'squaredIB'        : True,
    }

# Turn the run-specific relative directory into a full path under savedirbase.
savedir = savedirbase + savedir
# Create the output directory up front so that later cells, which write result
# files and checkpoints under savedir, do not depend on it already existing
# (e.g. on a fresh checkout).
os.makedirs(savedir, exist_ok=True)

# Fill in shared defaults without overriding anything the run type set itself.
for k, v in base_cfg.items():
    cfg.setdefault(k, v)

# Ten log-spaced beta values starting at 1e-3; the squared-IB objective scans
# up to 10**0.5, the plain objective up to 10**0.
max_log10_beta = .5 if cfg['squaredIB'] else 0.
betavals = 10**np.linspace(-3, max_log10_beta, 10, endpoint=True)


  from ._conv import register_converters as _register_converters


In [2]:
# Start from an empty default graph so that re-running this cell does not add
# a second copy of the network to the existing graph.
tf.reset_default_graph()
sess = tf.Session()

# Network hyperparameters taken verbatim from cfg ...
net_kwargs = {
    name: cfg[name]
    for name in ('input_dims', 'encoder_arch', 'decoder_arch', 'err_func',
                 'entropyY', 'squaredIB', 'gradient_train_noisevar')
}
# ... plus the fixed starting values for the noise variance and beta.
net_kwargs.update(noisevar=0.01, init_beta=0.0)

n = model.Net(**net_kwargs)
# Created after the network is built; used below to save and restore the
# base-model checkpoint.
saver = tf.train.Saver()


In [3]:
# Base model: trained on the cross-entropy loss only (beta = 0.0), then
# checkpointed so that every run of the beta sweep can start from it.
# The recreate_basemodel flag decides whether to retrain it or to reuse the
# checkpoint written by an earlier execution of this cell.
if recreate_basemodel:
    print("Making base model")
    cfg2 = cfg.copy()
    #cfg2['n_epochs'] = 10
    cfg2['beta'] = 0.0
    sess.run(tf.global_variables_initializer())
    trainutils.train(sess, cfg2, data, n, n.ce_loss, report_every, fname=savedir+'/results-base')

    save_path = saver.save(sess, savedir+'/_tf_basemodel')
    print("Model saved in path: %s" % save_path)
else:
    # Reuse the existing checkpoint instead of retraining; restore it into the
    # session so the variables hold trained values for the following cells.
    save_path = savedir+'/_tf_basemodel'
    saver.restore(sess, save_path)
    print("Model restored from path: %s" % save_path)

Making base model

epoch: 1 | beta: 0.0000 | noisevar: 0.01 | kw: 0.00403459
ce:  0.69 /  0.69 | acc:  0.50 /  0.51 | nlIBloss: -0.00 / -0.01 | VIBloss: -0.00 / -0.01 | 
Ixt:  2.75 /  2.76 | Ixt_lb:  1.59 /  1.59 | vIxt:  8.53 /  8.53 | Iyt:  0.00 /  0.01 | 

epoch: 11 | beta: 0.0000 | noisevar: 0.01 | kw: 0.0419104
ce:  0.11 /  0.11 | acc:  0.95 /  0.95 | nlIBloss: -0.58 / -0.58 | VIBloss: -0.58 / -0.58 | 
Ixt:  5.51 /  5.51 | Ixt_lb:  4.24 /  4.24 | vIxt:  14.06 /  14.00 | Iyt:  0.58 /  0.58 | 

epoch: 21 | beta: 0.0000 | noisevar: 0.01 | kw: 0.0427475
ce:  0.06 /  0.07 | acc:  0.98 /  0.97 | nlIBloss: -0.64 / -0.63 | VIBloss: -0.64 / -0.63 | 
Ixt:  5.70 /  5.70 | Ixt_lb:  4.43 /  4.43 | vIxt:  14.75 /  14.55 | Iyt:  0.64 /  0.63 | 

epoch: 31 | beta: 0.0000 | noisevar: 0.01 | kw: 0.0439873
ce:  0.04 /  0.05 | acc:  0.98 /  0.98 | nlIBloss: -0.65 / -0.65 | VIBloss: -0.65 / -0.65 | 
Ixt:  5.77 /  5.78 | Ixt_lb:  4.50 /  4.52 | vIxt:  15.04 /  15.29 | Iyt:  0.65 /  0.65 | 

epoch: 41 |

In [4]:
#sess  = tf.Session()
#n = model.Net(encoder_arch=[(512,'relu'),(512,'relu'),(2,'relu')], decoder_arch=[(512,'relu'),],
#              trainable_sigma=cfg['train_sigma'], log_sigma2=-2, log_eta2=-20, init_beta=0.0)
#loader = tf.train.import_meta_graph(basemodelpath+'.meta')
#saver.restore(sess, basemodelpath)


In [5]:
# Sweep over beta: for every repetition and every beta value, continue training
# from the saved base model once per objective (VIB first, then nlIB) and
# write each run's results to its own file.
# Maps a mode name to the name of the loss attribute on the network; the
# attribute is only looked up for the mode actually being run.
loss_attr_for_mode = {'nlIB': 'nlIB_loss', 'VIB': 'VIB_loss'}

for run_idx in range(n_runs):
    for beta in betavals:
        # Skip beta == 0: the base model was already trained with beta = 0.0.
        if np.isclose(beta, 0):
            continue
        for mode in ('VIB', 'nlIB'):
            # Every (beta, mode) run starts from the same base-model checkpoint.
            saver.restore(sess, savedir+'/_tf_basemodel')

            if mode not in loss_attr_for_mode:
                raise Exception('Unknown mode')
            loss = getattr(n, loss_attr_for_mode[mode])

            print("Doing %s, beta=%0.4f" % (mode, beta))
            results_path = savedir+'/results-%s-%0.5f-run%d' % (mode, beta, run_idx)
            # Per-run copy of the config so the shared cfg is never modified.
            run_cfg = cfg.copy()
            run_cfg.update(beta=beta, mode=mode)

            trainutils.train(sess, run_cfg, data, n, loss,
                             report_every=report_every,
                             fname=results_path,
                             fit_var=run_cfg['initial_fitvar'])

            # Two blank lines between runs in the log output.
            print("\n")

Doing VIB, beta=0.0010

epoch: 1 | beta: 0.0010 | noisevar: 0.01 | kw: 0.0865619
ce:  0.00 /  0.00 | acc:  1.00 /  1.00 | nlIBloss: -0.69 / -0.69 | VIBloss: -0.67 / -0.67 | 
Ixt:  6.35 /  6.31 | Ixt_lb:  5.19 /  5.16 | vIxt:  24.09 /  23.26 | Iyt:  0.69 /  0.69 | 

epoch: 11 | beta: 0.0010 | noisevar: 0.00966099 | kw: 0.0574751
ce:  0.00 /  0.00 | acc:  1.00 /  1.00 | nlIBloss: -0.68 / -0.68 | VIBloss: -0.67 / -0.67 | 
Ixt:  6.13 /  6.13 | Ixt_lb:  4.96 /  4.97 | vIxt:  20.70 /  20.83 | Iyt:  0.69 /  0.69 | 

epoch: 21 | beta: 0.0010 | noisevar: 0.00915047 | kw: 0.039965
ce:  0.00 /  0.00 | acc:  1.00 /  1.00 | nlIBloss: -0.69 / -0.69 | VIBloss: -0.67 / -0.67 | 
Ixt:  6.02 /  6.02 | Ixt_lb:  4.84 /  4.85 | vIxt:  19.36 /  19.62 | Iyt:  0.69 /  0.69 | 

epoch: 31 | beta: 0.0010 | noisevar: 0.00884496 | kw: 0.0250927
ce:  0.00 /  0.00 | acc:  1.00 /  1.00 | nlIBloss: -0.69 / -0.69 | VIBloss: -0.67 / -0.67 | 
Ixt:  5.92 /  5.93 | Ixt_lb:  4.72 /  4.73 | vIxt:  18.52 /  18.65 | Iyt:  0.69 


epoch: 21 | beta: 0.0010 | noisevar: 0.00871037 | kw: 0.0359986
ce:  0.00 /  0.00 | acc:  1.00 /  1.00 | nlIBloss: -0.69 / -0.68 | VIBloss: -0.67 / -0.67 | 
Ixt:  6.35 /  6.34 | Ixt_lb:  5.14 /  5.13 | vIxt:  24.48 /  24.02 | Iyt:  0.69 /  0.69 | 

epoch: 31 | beta: 0.0010 | noisevar: 0.00801644 | kw: 0.0773905
ce:  0.04 /  0.04 | acc:  0.99 /  0.99 | nlIBloss: -0.65 / -0.65 | VIBloss: -0.63 / -0.63 | 
Ixt:  6.46 /  6.44 | Ixt_lb:  5.31 /  5.30 | vIxt:  23.65 /  23.38 | Iyt:  0.66 /  0.65 | 

epoch: 41 | beta: 0.0010 | noisevar: 0.00748686 | kw: 0.0443207
ce:  0.00 /  0.00 | acc:  1.00 /  1.00 | nlIBloss: -0.68 / -0.69 | VIBloss: -0.67 / -0.67 | 
Ixt:  6.44 /  6.46 | Ixt_lb:  5.26 /  5.28 | vIxt:  24.11 /  24.81 | Iyt:  0.69 /  0.69 | 

epoch: 51 | beta: 0.0010 | noisevar: 0.00759476 | kw: 0.0601967
ce:  0.00 /  0.00 | acc:  1.00 /  1.00 | nlIBloss: -0.69 / -0.69 | VIBloss: -0.67 / -0.67 | 
Ixt:  6.44 /  6.46 | Ixt_lb:  5.29 /  5.31 | vIxt:  24.56 /  24.96 | Iyt:  0.69 /  0.69 | 

epo

KeyboardInterrupt: 

In [None]:
# Code to plot activations
%matplotlib inline 
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(10,10))
x, y = data['tst_data'], data['tst_labels']
mx = sess.run(n.encoder[-1], feed_dict={n.x: x})
var = sess.run(n.noisevar)
#plt.scatter(mx[:,0],mx[:,1], s=var, c=np.argmax(y,axis=1), alpha=1)
ax = plt.axes()
for r in mx:
    c = plt.Circle((r[0], r[1]), radius=np.sqrt(var), fc='none', alpha=0.05, ec='k')
    ax.add_patch(c)
plt.axis('scaled');

In [None]:
import entropy
#mx = sess.run(n.encoder[-1], feed_dict={n.x: x})

# Build the pairwise-distance op on the last encoder layer, then the
# GMM_entropy op on top of it, and evaluate that op on x.
# NOTE(review): the meaning of the arguments (1., 2, 'upper') is defined in the
# project's entropy module and is not visible here -- confirm there.
encoder_dists = entropy.pairwise_distance(n.encoder[-1])
gmm_entropy_op = entropy.GMM_entropy(encoder_dists, 1., 2, 'upper')
sess.run(gmm_entropy_op, feed_dict={n.x: x})