In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import scipy
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import time, os, pickle, pathlib

import model
import trainutils

report_every = 10  # how often to print stats during training
n_runs       = 1   # how many times to repeat the whole scan across beta's
savedirbase  = str(pathlib.Path().absolute()) + '/saveddata/'

optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999)

base_cfg = {
    'n_batch'         : 128 ,       # SGD batch size
    'train_noisevar'  : 'gradient', # train noise variance with gradient descent ('gradient'), 
                                    #  scipy optimizer loop ('scipy'), or leave fixed ('none')
    'n_noisevar_batch': 1000,       # batch size for training noise variance when train_noisevar='scipy'
    'initial_fitvar'  : False,      # whether to set noisevar to optimal value before training
    'squaredIB'       : True,       # optimize I(Y;T)-beta*I(X;T) or I(Y;T)-beta*I(X;T)^2 
    'err_func'        : 'softmax_ce', # 'softmax_ce' for classification, 'mse' for regression  
}


runtype = 'MNIST'
#runtype = 'NoisyClassifier'
#runtype = 'Regression'
assert(runtype in ['MNIST', 'NoisyClassifier', 'Regression'])

if runtype == 'MNIST':
    data = trainutils.load_mnist()
    savedir = runtype + '/v2'
    cfg = {
        'input_dims'  : 784,
        'entropyY'    : np.log(10), 
        'n_epochs'    : 150,
        'squaredIB'   : True,
        'encoder_arch': [(512,tf.nn.relu),(512,tf.nn.relu),(2,None)], 
        'decoder_arch': [(512,tf.nn.relu),(10,None)],
    }

elif runtype == 'NoisyClassifier':
    savedir = runtype + '/v5sq'
    # Data from artificial dataset used in Schwartz-Ziv and Tishby
    d1 = scipy.io.loadmat('data/g1.mat')
    d2 = scipy.io.loadmat('data/g2.mat')
    data = { 'trn_X' : d1['F'].astype('float32'), 'trn_Y': one_hot(d1['y'].flat),
             'tst_X' : d2['F'].astype('float32'), 'tst_Y': one_hot(d2['y'].flat)}
    cfg = {
        'input_dims'    : 12,
        'entropyY'      : np.log(2),
        'n_epochs'      : 200,
        'squaredIB'     : True,
        'encoder_arch'  : [(20,tf.nn.relu),(20,tf.nn.relu),(2,None)], 
        'decoder_arch'  : [(20,tf.nn.relu),(2,None)],
    }
    
elif runtype == 'Regression':
    savedir = runtype + '/v5sq-eta'
    # data generated by makeregressiondata.py
    with open('data/regression-100-10.pkl', 'rb') as f:
        data = pickle.load(f)
    
    labelcov = np.cov(data['trn_Y'].T)
    entropyY = 0.5 * np.log(np.linalg.det(2*np.pi*np.exp(1)*labelcov))
    cfg = {
        'input_dims'       : data['trn_X'].shape[1],
        'entropyY'         : entropyY,
        'n_epochs'         : 200,
        'encoder_arch'     : [(100,tf.nn.relu),(100,tf.nn.relu),(2,None)], 
        'decoder_arch'     : [(100,tf.nn.relu),(10,None)],
        'err_func'         : 'mse',
        'squaredIB'        : True,
    }
    
else:
    raise Exception('unknown runtype')
    
savedir = savedirbase + savedir
for k, v in base_cfg.items():
    if k not in cfg: 
        cfg[k] = v
cfg['optimizer'] = repr(optimizer)

if cfg['squaredIB']:
    betavals = 10**np.linspace(-3, .5, 10, endpoint=True)
else:
    betavals = 10**np.linspace(-1, 0., 10, endpoint=True)


  from ._conv import register_converters as _register_converters


In [2]:
# Build the network in a fresh default graph. First close any session left over
# from a previous run of this cell: the TF docs warn that calling
# tf.reset_default_graph() while a session is active is undefined behaviour,
# and an abandoned session keeps holding its (possibly GPU) resources.
_prev_sess = globals().get('sess')
if _prev_sess is not None:
    _prev_sess.close()
del _prev_sess
tf.reset_default_graph()
sess = tf.Session()


n = model.Net(input_dims   = cfg['input_dims'],
              encoder_arch = cfg['encoder_arch'], 
              decoder_arch = cfg['decoder_arch'],
              err_func     = cfg['err_func'],
              entropyY     = cfg['entropyY'],
              # noise variance is a trainable variable only in 'gradient' mode (see base_cfg)
              trainable_noisevar = cfg['train_noisevar']=='gradient', 
              noisevar     = 0.01)  # initial noise variance (logged as 'noisevar: 0.01' at epoch 1)
# NOTE(review): a Saver built with no var_list captures the variables that exist
# *now*; anything created later (e.g. optimizer slot variables, if
# trainutils.train builds the train op) is not saved/restored -- confirm intended.
saver = tf.train.Saver()


In [3]:
print("Making base model")
# Train the beta=0 ('ce' mode) base model from freshly initialised weights;
# every run of the beta scan below restores from the checkpoint saved here.
sess.run(tf.global_variables_initializer())

# TF1's Saver.save refuses to write when the checkpoint's parent directory does
# not exist, so make sure savedir is there before training writes into it.
os.makedirs(savedir, exist_ok=True)

trainutils.train(sess, 'ce', 0.0, cfg, data, n, optimizer,
                 report_every=report_every, fname=savedir+'/results-base')

save_path = saver.save(sess, savedir+'/_tf_basemodel')
print("Model saved in path: %s" % save_path)

Making base model

mode: ce epoch: 1 | beta: 0.0000 | noisevar: 0.01 | kw: 0.0137392
ce:  2.308/ 2.307 | acc:  0.097/ 0.101 | loss:  2.308/ 2.307 | 
Ixt:  3.291/ 3.251 | Ixt_lb:  2.081/ 2.050 | vIxt:  9.138/ 9.116 | Iyt: -0.005/-0.004 | 

mode: ce epoch: 11 | beta: 0.0000 | noisevar: 0.01 | kw: 1.63276
ce:  0.066/ 0.082 | acc:  0.980/ 0.974 | loss:  0.066/ 0.082 | 
Ixt:  10.180/ 10.191 | Ixt_lb:  9.196/ 9.201 | vIxt:  343.521/ 332.489 | Iyt:  2.236/ 2.220 | 

mode: ce epoch: 21 | beta: 0.0000 | noisevar: 0.01 | kw: 2.28921
ce:  0.032/ 0.026 | acc:  0.990/ 0.991 | loss:  0.032/ 0.026 | 
Ixt:  10.622/ 10.612 | Ixt_lb:  9.629/ 9.612 | vIxt:  562.379/ 563.456 | Iyt:  2.271/ 2.277 | 

mode: ce epoch: 31 | beta: 0.0000 | noisevar: 0.01 | kw: 3.33712
ce:  0.031/ 0.037 | acc:  0.989/ 0.989 | loss:  0.031/ 0.037 | 
Ixt:  10.881/ 10.905 | Ixt_lb:  9.925/ 9.952 | vIxt:  852.996/ 872.600 | Iyt:  2.272/ 2.265 | 

mode: ce epoch: 41 | beta: 0.0000 | noisevar: 0.01 | kw: 5.4792
ce:  0.015/ 0.012 | ac

In [4]:
# Scratch cell (does nothing when run): an old recipe for reloading the base model.
# NOTE(review): the commented lines below are stale -- `basemodelpath` is never
# defined in this notebook, and the Net kwargs used here (trainable_sigma,
# log_sigma2, log_eta2, init_beta) do not match the model.Net call that builds
# the current graph. The base checkpoint is written to savedir+'/_tf_basemodel',
# and the beta-scan loop reloads it with:
#     saver.restore(sess, savedir+'/_tf_basemodel')
#sess  = tf.Session()
#n = model.Net(encoder_arch=[(512,'relu'),(512,'relu'),(2,'relu')], decoder_arch=[(512,'relu'),],
#              trainable_sigma=cfg['train_sigma'], log_sigma2=-2, log_eta2=-20, init_beta=0.0)
#loader = tf.train.import_meta_graph(basemodelpath+'.meta')
#saver.restore(sess, basemodelpath)


In [None]:
# Beta scan: for every repeat, every non-zero beta and both IB variants, start
# again from the saved base model and train, writing one results file per
# (mode, beta, run) under savedir.
scan_betas = [b for b in betavals if not np.isclose(b, 0)]
for runndx in range(n_runs):
    for beta in scan_betas:
        for mode in ('VIB', 'nlIB'):
            saver.restore(sess, savedir + '/_tf_basemodel')
            print("Doing %s, beta=%0.4f" % (mode, beta))
            fname = savedir + '/results-%s-%0.5f-run%d' % (mode, beta, runndx)
            trainutils.train(
                sess, mode, beta, cfg, data, n, optimizer,
                report_every=report_every,
                fname=fname,
                fit_var=cfg['initial_fitvar'],
            )
            # two blank lines between runs
            print('\n')

Doing VIB, beta=0.0010

mode: VIB epoch: 1 | beta: 0.0010 | noisevar: 0.01 | kw: 48.2735
ce:  0.046/ 0.041 | acc:  0.991/ 0.993 | loss:  325164.250/ 342443.438 | 
Ixt:  13.624/ 13.685 | Ixt_lb:  12.685/ 12.737 | vIxt:  18032.373/ 18505.287 | Iyt:  2.257/ 2.261 | 


In [None]:
# Code to plot activations
%matplotlib inline 
import matplotlib.pyplot as plt
import seaborn as sns
if False:
    plt.figure(figsize=(10,10))
    x, y = data['tst_X'], data['tst_Y']
    mx = sess.run(n.encoder[-1], feed_dict={n.x: x})
    var = sess.run(n.noisevar)
    ax = plt.axes()
    for r in mx:
        c = plt.Circle((r[0], r[1]), radius=np.sqrt(var), fc='none', alpha=0.05, ec='k')
        ax.add_patch(c)
    plt.axis('scaled');