In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

import loaddata, iblayer, trainutils

# Experiment-wide training settings, read by Network and by the training cells below.
cfg = dict(
    squaredIB=True,          # picks the 'sq' (vs 'reg') IB variant in the sweep's save names
    n_batch=128,
    n_epochs=150,
    n_noisevar_batch=1000,   # forwarded to iblayer.NoisyIBLayer by Network
    report_every=10,
)

# 30 log-spaced IB trade-off weights, from 1e-5 up to 10**0.1 (about 1.26).
betavals = np.logspace(-5, 0.1, num=30, endpoint=True)

# Absolute output location kept as a plain str: later cells append
# sub-directories with string concatenation (savedir + '/basemodel').
savedirbase = '%s/saveddata-testthread/' % pathlib.Path().absolute()
savedir = savedirbase + 'MNIST'

# Session config shared by every tf.Session below: turn on allow_growth so GPU
# memory is allocated as needed rather than pre-reserved.
tfconfig = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))



  from ._conv import register_converters as _register_converters


In [2]:
class Network(object):
    """Fully connected classifier with a noisy information-bottleneck layer (TF1 graph mode).

    Architecture: X -> Dense(hidden_dim, relu) -> Dense(hidden_dim, relu)
    -> Dense(bottleneck_dim) -> iblayer.NoisyIBLayer -> Dense(hidden_dim, relu)
    -> Dense(output_dim) logits.

    The graph is built into the current default graph. Other cells look tensors
    up by name, so the input placeholder stays named 'X' (fed as 'X:0') and the
    Dense layers keep the names encoder_0..2 / decoder_0..1.

    Attributes:
        input_dim, output_dim: the sizes passed to the constructor.
        iblayerobj: the NoisyIBLayer instance (training cells use its optimize_eta).
        true_outputs: label placeholder of shape [None, output_dim]; presumably
            one-hot, given the argmax used for accuracy.
        layers: [X, encoder_0, encoder_1, encoder_2, IB output, decoder_0,
            decoder_1]. Callers may index this list, so the order is fixed.
        inputs: layers[0], the 'X' placeholder.
        predictions: layers[-1], unnormalised logits.
        cross_entropy: batch-mean softmax cross-entropy of predictions vs labels.
        accuracy: fraction of rows where argmax(predictions) == argmax(labels).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=512, bottleneck_dim=2,
                 n_noisevar_batch=None):
        """Build the network graph.

        Args:
            input_dim: number of input features.
            output_dim: number of output classes.
            hidden_dim: width of the ReLU encoder/decoder layers (previously hard-coded 512).
            bottleneck_dim: width of the encoder output fed to the IB layer
                (previously hard-coded 2).
            n_noisevar_batch: forwarded to NoisyIBLayer. None keeps the old
                behaviour of reading the global cfg['n_noisevar_batch'].
        """
        if n_noisevar_batch is None:
            n_noisevar_batch = cfg['n_noisevar_batch']

        self.input_dim      = input_dim
        self.output_dim     = output_dim

        # init_kdewidth appears to be on a log scale: the training log reports
        # kw = 2.06115e-09, which equals exp(-20). NOTE(review): confirm in iblayer.
        self.iblayerobj     = iblayer.NoisyIBLayer(n_noisevar_batch=n_noisevar_batch,
                                                   init_noisevar=0.01, init_kdewidth=-20)

        self.true_outputs   = tf.placeholder(tf.float32, [None, output_dim], name='true_outputs')

        Dense = tf.keras.layers.Dense
        self.layers = [tf.placeholder(tf.float32, [None, input_dim], name='X')]
        # Encoder down to the bottleneck (linear last layer).
        self.layers.append(Dense(hidden_dim, activation=tf.nn.relu, name='encoder_0')(self.layers[-1]))
        self.layers.append(Dense(hidden_dim, activation=tf.nn.relu, name='encoder_1')(self.layers[-1]))
        self.layers.append(Dense(bottleneck_dim, activation=None, name='encoder_2')(self.layers[-1]))
        # Noisy IB layer, then the decoder producing class logits.
        self.layers.append(self.iblayerobj(self.layers[-1]))
        self.layers.append(Dense(hidden_dim, activation=tf.nn.relu, name='decoder_0')(self.layers[-1]))
        self.layers.append(Dense(output_dim, activation=None, name='decoder_1')(self.layers[-1]))

        self.inputs         = self.layers[0]
        self.predictions    = self.layers[-1]

        # The _v2 variant also backpropagates into labels. That is harmless here
        # because the labels are a placeholder.
        per_example_ce      = tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.true_outputs,
                                                                         logits=self.predictions)
        self.cross_entropy  = tf.reduce_mean(per_example_ce)
        correct             = tf.equal(tf.argmax(self.predictions, 1), tf.argmax(self.true_outputs, 1))
        self.accuracy       = tf.reduce_mean(tf.cast(correct, tf.float32))

In [3]:
# Load MNIST once. The network's input and output widths are taken from the
# second axis of the training arrays.
data = loaddata.load_data('MNIST')
input_dim, output_dim = (data[key].shape[1] for key in ('trn_X', 'trn_Y'))

In [21]:
# Smoke test: build the model, call trainutils.train with 0 epochs (it saves
# into cur_dir, per the log below), and write the graph for TensorBoard.
import os
import shutil

cur_dir = 'tbtest/'

# Clear previous outputs but keep hidden files. Create the directory first so
# os.listdir does not raise on a fresh checkout, and remove sub-directories
# too (os.remove cannot delete them).
os.makedirs(cur_dir, exist_ok=True)
for entry in os.listdir(cur_dir):
    if entry.startswith('.'):
        continue
    path = os.path.join(cur_dir, entry)
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)

tf.reset_default_graph()
with tf.Session(config=tfconfig) as sess:
    print("Making base model")
    n = Network(input_dim, output_dim)
    sess.run(tf.global_variables_initializer())
    smoke_cfg = dict(cfg, n_epochs=0)   # same settings, but zero training epochs
    trainutils.train(sess, mode='ce', beta=0.0, net=n, cfg=smoke_cfg,
                     data=data, savedir=cur_dir)
    print("Model saved in path: %s" % cur_dir)

    # Write the graph once, after train() has added its ops. Close the writer
    # so the event file is flushed and the handle released. Previously two
    # writers were opened and neither was closed.
    file_writer = tf.summary.FileWriter(cur_dir, sess.graph)
    file_writer.close()
    del n

# TODO: the "reload" half of this test is not implemented. Untested sketch:
#   tf.reset_default_graph()
#   with tf.Session(config=tfconfig) as sess:
#       saver = tf.train.import_meta_graph(cur_dir + '/tf_model-0.meta')
#       saver.restore(sess, tf.train.latest_checkpoint(cur_dir))

Making base model
*** Saving to tbtest/ ***

mode: ce epoch: 0 | beta: 0.0000 | noisevar: 0.01 | kw: 2.06115e-09 | time/epoch: -
ce:  2.303/ 2.304 | acc:  0.125/ 0.120 | loss:  2.303/ 2.304 | 
Ixt:  3.360/ 3.410 | Ixt_lb:  2.043/ 2.090 | vIxt:  8.334/ 8.343 | Iyt: -0.000/-0.001 | 
Model saved in path: tbtest/


In [None]:
# Train the base model without compression (plain cross-entropy, beta = 0).
# The beta sweep restores its starting weights from this checkpoint directory.
basemodel_dir = savedir + '/basemodel'
tf.reset_default_graph()
with tf.Session(config=tfconfig) as sess:
    print("Making base model")
    n = Network(input_dim, output_dim)
    sess.run(tf.global_variables_initializer())
    trainutils.train(sess, mode='ce', beta=0.0, net=n, cfg=cfg,
                     data=data, savedir=basemodel_dir)
    # Report the directory actually passed to train(). The old message printed
    # the parent `savedir` instead.
    print("Model saved in path: %s" % basemodel_dir)
    del n
    

In [None]:
# Beta sweep: for each IB variant and each beta, start from the base-model
# checkpoint, fine-tune with the IB penalty, and save each run to its own dir.
sqmode = 'sq' if cfg['squaredIB'] else 'reg'
runndx = 0

# The base checkpoint does not change during the sweep, so look it up once,
# and fail with a clear message instead of Saver.restore(sess, None).
basemodel_dir = savedir + '/basemodel'
base_ckpt = tf.train.latest_checkpoint(basemodel_dir)
if base_ckpt is None:
    raise FileNotFoundError('No base-model checkpoint in %s; run the base-model cell first.'
                            % basemodel_dir)

# '%0.5f' formatted beta=1e-05 and beta~1.499e-05 both as '0.00001', so the
# second run overwrote the first run's directory. Use more fixed-point digits
# (no '-' added to the name) and check that the tags are unique.
# NOTE(review): any loader that rebuilds names with '%0.5f' must be updated.
beta_tags = ['%0.8f' % beta for beta in betavals]
assert len(set(beta_tags)) == len(beta_tags), 'beta values collide in save names'

for mode in ['nlIB', 'VIB']:
    for beta, beta_tag in zip(betavals, beta_tags):
        if np.isclose(beta, 0):
            continue   # beta = 0 corresponds to the base model trained above
        tf.reset_default_graph()
        with tf.Session(config=tfconfig) as sess:
            n = Network(input_dim, output_dim)
            sess.run(tf.global_variables_initializer())
            tf.train.Saver().restore(sess, base_ckpt)

            savename = savedir + '/results-%s-%s-%s-run%d' % (mode, beta_tag, sqmode, runndx)
            trainutils.train(sess, mode=mode, beta=beta, net=n, cfg=cfg, data=data, savedir=savename,
                             optimization_callback=n.iblayerobj.optimize_eta)
            # Report the per-run directory, not the shared parent `savedir`.
            print("Model saved in path: %s" % savename)
            del n
            print()
            print()

In [None]:
# TODO(review): `act` is not defined anywhere in this notebook. If it is meant to be an
# (N, 2) array of bottleneck activations, the second coordinate is act[:, 1];
# act[1,:] below selects a row instead.
# plt.scatter(act[:,0], act[1,:], alpha=0.1)


In [None]:
# Display the beta grid used by the sweep (defined in the configuration cell).
betavals