In [None]:
# This notebook generated plots from results of do_run.py
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import hashlib
import pickle, time, os

import matplotlib

import matplotlib.pyplot as plt
import seaborn as sns
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
sns.set(style='whitegrid', font='STIXGeneral', font_scale=1.25)


import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Whether to use GPU or not
tf.logging.set_verbosity(tf.logging.ERROR)
tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth=True

import entropy, loaddata, mc_entropy, loaddata


In [None]:
def loadresults(sess, data, basedir, sfx, verbose=0, rerun=False, store_batch=False):
    """Load saved runs from `basedir` and attach Monte-Carlo estimates of I(X;T).

    Every sub-directory of `basedir` whose name ends with `sfx` is expected to
    contain a pickled ``(cfg, saved_data)`` tuple in a file named ``data``.
    The row ``saved_data[-cfg['patience']]`` is used as the reported epoch. Its
    TF checkpoint is restored into `sess`, and the activations of
    ``noisy_ib_layer/RawInput:0`` are evaluated on a fixed random batch of (up
    to) 2000 inputs from the *same* split. The result is stored as
    ``Ixt_mc = mc_entropy.get_mc_entropy(mx, var) - entropy.gaussian_entropy_np(d, var)``.
    Per-split results are cached in ``.mc_entropy_cache`` under the md5 of the
    checkpoint data file plus the split name.

    NOTE: caches written by older versions of this function estimated the
    'tst' split on a training batch; pass ``rerun=True`` to refresh them.

    Args:
        sess: TF session the checkpoints are restored into.
        data: dataset dict; reads 'runtype', 'err', '<split>_X', '<split>_Y'
            and, when ``data['err'] == 'mse'``, '<split>_entropyY'.
        basedir: directory containing one sub-directory per run.
        sfx: suffix selecting which sub-directories to load.
        verbose: >0 prints each file name; >1 also prints per-split summaries.
        rerun: recompute even when a cached result exists (overwrites the cache).
        store_batch: also store 'hidden_acts' (activations on the whole split)
            and 'pred_y' (targets; argmax class index unless the error is 'mse').

    Returns:
        dict mapping (split, method, objective) -> list of result dicts, one per run.

    Raises:
        Exception: if a run's ``cfg['runtype']`` differs from ``data['runtype']``.
    """
    cache_dir = '.mc_entropy_cache'
    os.makedirs(cache_dir, exist_ok=True)

    saver = None
    # One random input batch per split, reused for every run in this call so the
    # MC estimates of different runs are computed on the same inputs.
    xbatches = {}
    vals = {}

    for fname in os.listdir(basedir):
        if not fname.endswith(sfx):
            continue

        fullfname = os.path.join(basedir, fname, 'data')
        if verbose > 0:
            print(fullfname)
        with open(fullfname, 'rb') as f:
            try:
                # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
                cfg, saved_data = pickle.load(f)
            except Exception as e:
                print("Error loading pickle file %s: %r" % (fullfname, e))
                continue

        if not vals:  # nothing collected yet in this call
            print("Doing", basedir, sfx)

        if cfg['runtype'] != data['runtype']:
            raise Exception('Wrong dataset')

        sqmode    = cfg['objective']
        beta      = saved_data[0]['beta']
        method    = saved_data[0]['method']  # method indicates 'ce' vs 'nlIB' vs 'VIB'
        backtrack = cfg['patience']

        # The reported epoch is `patience` rows before the last saved row.
        if len(saved_data) < backtrack:
            continue
        lrow  = saved_data[-backtrack]
        var   = lrow['noisevar']
        epoch = lrow['epoch']

        epoch_fname = '%s/%s/tf_model-%d' % (basedir, fname, epoch)
        epoch_fname_full = epoch_fname + '.data-00000-of-00001'
        if not os.path.exists(epoch_fname_full):
            print("ERROR: Cant find %s" % epoch_fname_full)
            print(backtrack, epoch, len(saved_data), saved_data[-1]['epoch'])
            continue

        with open(epoch_fname_full, 'rb') as ckpt:
            ckpt_hash = hashlib.md5(ckpt.read()).hexdigest()

        for trntstmode in ['trn', 'tst']:
            l = lrow[trntstmode]
            cache_file = os.path.join(cache_dir, ckpt_hash + '-' + trntstmode)

            cached = None
            if not rerun and os.path.exists(cache_file):
                with open(cache_file, 'rb') as f:
                    cached = pickle.load(f)
                # A cache written with store_batch=False has no activations;
                # recompute (and overwrite it) when they are requested.
                if store_batch and 'hidden_acts' not in cached:
                    cached = None

            if cached is not None:
                l = cached
            else:
                if saver is None:
                    # The meta graph is imported once per call; this assumes all
                    # runs under `basedir` share the same architecture.
                    saver = tf.train.import_meta_graph(tf.train.latest_checkpoint('%s/%s' % (basedir, fname)) + '.meta')
                saver.restore(sess, epoch_fname)

                if trntstmode not in xbatches:
                    split_X = data[trntstmode + '_X']
                    permutation = np.random.permutation(len(split_X))
                    xbatches[trntstmode] = split_X[permutation[:2000]]

                mx = sess.run('noisy_ib_layer/RawInput:0', feed_dict={'X:0': xbatches[trntstmode]})

                d     = mx.shape[1]
                mcH   = mc_entropy.get_mc_entropy(mx, var)
                hCond = entropy.gaussian_entropy_np(d, var)
                l['beta']     = beta
                l['Ixt_mc']   = mcH - hCond
                l['noisevar'] = var

                if store_batch:
                    l['hidden_acts'] = sess.run('noisy_ib_layer/RawInput:0', feed_dict={'X:0': data[trntstmode + '_X']})
                    l['pred_y']      = data[trntstmode + '_Y']
                    if data['err'] != 'mse':
                        l['pred_y'] = np.argmax(data[trntstmode + '_Y'], axis=1)

                with open(cache_file, 'wb') as f:
                    pickle.dump(l, f)

            val_dict = l

            if method == 'ce':
                val_dict['Ixt_mc'] = np.nan

            val_dict['err'] = data['err']

            if data['err'] == 'mse':
                # Differential entropy of a Gaussian whose variance is the MSE stored under 'ce'.
                hcond = 0.5*(np.log(2*np.pi*val_dict['ce']) + 1)
                val_dict['Iyt'] = data[trntstmode + '_entropyY'] - hcond

            if verbose > 1:
                print('%s %s beta: %4g ce: %0.3f Ixt: %0.3f-%0.3f (%0.3f) Iyt: %0.3f' %
                      (trntstmode, sqmode, beta, val_dict['ce'], val_dict['Ixt_lb'], val_dict['Ixt'], val_dict['Ixt_mc'], val_dict['Iyt']))

            vals.setdefault((trntstmode, method, sqmode), []).append(val_dict)

        if verbose > 1:
            print()

    return vals



def plotresults(runtype, baseline, run_results, axes, methods=None, sqmodes=None, plotmode=0):
    """Plot information-plane curves I(Y;M) vs I(X;M), in bits, for train and test.

    Args:
        runtype: dataset name; used in titles, x-limits and the baseline label.
        baseline: dict as returned by `loadresults` for the baseline run; entries
            under key (split, 'ce', sqmode) give a horizontal reference line.
        run_results: dict mapping (split, method, sqmode) -> list of runs, each
            run a list of result dicts carrying 'Ixt_mc' and 'Iyt' (in nats).
        axes: pair of matplotlib axes, [training panel, testing panel].
        methods: optional whitelist of methods to plot.
        sqmodes: optional whitelist of objectives ('sq' / 'reg') to plot.
        plotmode: 0 plots the mean +/- standard error across runs on a common
            30-point I(X;M) grid; any other value plots every run separately.

    Returns:
        The `axes` argument.
    """
    MULT = 1./np.log(2)  # nats to bits

    def sort_and_interp(newx, oldx, oldy):
        # Interpolate the curve (oldx, oldy) at newx; NaN outside the observed x-range.
        ix = np.argsort(oldx)
        oldx = np.array(oldx)[ix]
        oldy = np.array(oldy)[ix]
        return np.interp(newx, oldx, oldy, left=np.nan, right=np.nan)

    colordict = {('nlIB','reg'):'blue', ('nlIB','sq'):'blue',
                 ('VIB','reg'):'red', ('VIB','sq'):'red',
                 ('VIBraw','reg'):'orange', ('VIBraw','sq'):'orange'}
    # Legend names; methods not listed (e.g. 'VIBraw') fall back to their raw name.
    labeldict = {'nlIB':'nonlinear IB', 'VIB':'VIB'}

    baselineIyt = {}
    svals = {}  # mean I(Y;M) on the common grid, per (split, method, sqmode)

    # Common I(X;M) grid. 'ce' runs store Ixt_mc = NaN, and Python's min/max give
    # order-dependent results when NaNs are present, so use NaN-aware reductions.
    all_ixt = [entry['Ixt_mc'] for results in run_results.values() for run in results for entry in run]
    minX = np.nanmin(all_ixt)
    maxX = np.nanmax(all_ixt)
    newx = np.linspace(minX, maxX, 30, endpoint=True)

    for k, results in run_results.items():
        trntstmode, method, sqmode = k

        if not len(results):                              continue
        if sqmodes is not None and sqmode not in sqmodes: continue
        if methods is not None and method not in methods: continue

        plt.sca(axes[0 if trntstmode == 'trn' else 1])

        methodcolor = colordict[(method, sqmode)]
        linestyle   = '-'
        label       = labeldict.get(method, method)

        # Resample every run onto the common grid so runs can be averaged.
        resamp_y = np.array([sort_and_interp(newx,
                                             [entry['Ixt_mc'] for entry in run],
                                             [entry['Iyt'] for entry in run])
                             for run in results])
        mean_y = np.nanmean(resamp_y, axis=0)
        # standard error of the mean, over the runs that cover each grid point
        errorbars = np.nanstd(resamp_y, axis=0) / np.sqrt(np.sum(~np.isnan(resamp_y), axis=0))

        ix = np.argmin((newx - np.log(10))**2)
        print(trntstmode, method, sqmode, "Mean I(Y;T) at I(X;T)=log 10: ", MULT*mean_y[ix])

        if plotmode == 0:
            plt.plot(MULT*newx, MULT*mean_y, ls=linestyle, color=methodcolor, label=label)
            if resamp_y.shape[0] > 1:
                plt.fill_between(MULT*newx, MULT*(mean_y - errorbars),
                                 MULT*(mean_y + errorbars), color=methodcolor, alpha=0.3)
        else:
            for runndx, run in enumerate(results):
                cX = np.array([entry['Ixt_mc'] for entry in run])
                cY = np.array([entry['Iyt']    for entry in run])
                order = np.argsort(cX)
                # Label only the first run so the legend has one entry per method.
                plt.plot(MULT*cX[order], MULT*cY[order], marker='x', ls=linestyle, color=methodcolor,
                         label=label if runndx == 0 else '_nolegend_')

        svals[k] = mean_y

    for trntstmode in ['trn','tst']:
        for sqmode in ['sq','reg']:
            if (trntstmode, 'ce', sqmode) in baseline:
                baselineIyt[trntstmode] = np.mean([r['Iyt'] for r in baseline[(trntstmode, 'ce', sqmode)]])

        plt.sca(axes[0 if trntstmode == 'trn' else 1])

        if trntstmode in baselineIyt:
            err = 'MSE' if runtype == 'Housing' else 'CE'
            plt.hlines(MULT*baselineIyt[trntstmode], 0, MULT*4, linestyles=':', label=err + ' only')

        plt.xlabel('$I(X;M)$ (bits)')
        if trntstmode == 'trn':
            plt.ylabel('$I(Y;M)$ (bits)')
        plt.legend(loc='lower right')
        plt.title(runtype + ' - ' + {'tst':'Testing', 'trn':'Training'}[trntstmode])

        ylims = plt.ylim()
        plt.plot([0, MULT*5], [0, MULT*5], '--k')  # I(Y;M) = I(X;M) reference diagonal

        for sqmode in ['reg','sq']:
            if sqmodes is not None and sqmode not in sqmodes: continue
            if trntstmode == 'tst' and sqmode == 'sq':
                plt.vlines(MULT*np.log(10), 0, 5, color='green', alpha=0.5, zorder=-1)

            k1, k2 = (trntstmode,'nlIB',sqmode), (trntstmode,'VIB',sqmode)
            if k1 in svals and k2 in svals:
                perfdiff  = MULT*(svals[k1]-svals[k2])
                perfratio = svals[k1]/svals[k2]
                if not np.all(np.isnan(perfratio)):
                    print('max rel improvment %s %s: %g at %g' % (trntstmode, sqmode, np.nanmax(perfratio), newx[np.nanargmax(perfratio)]))
                if not np.all(np.isnan(perfdiff)):
                    print('max abs improvment %s %s: %g at %g' % (trntstmode, sqmode, np.nanmax(perfdiff) , newx[np.nanargmax(perfdiff)]))

        if runtype.endswith('MNIST'):
            plt.xlim([0, 6])
        else:
            plt.xlim([0, 5])

        plt.ylim([0, ylims[1]*1.1])

    return axes


In [None]:
# IB curve plots: for each dataset and each results sub-directory, load the
# baseline ('basemodel') and up to `maxruns` IB runs ('run0', 'run1', ...),
# then plot I(Y;M) vs I(X;M) separately for each objective ('sq' / 'reg').

savefig    = False
maxruns    = 5
resultsdir = 'results8'
dirfilters = None  # optional list of sub-directory name prefixes to include
figdir     = 'outputpdf/' + resultsdir

run_results = {}
baselines   = {}

for runtype in ['MNIST', 'FashionMNIST', 'Housing']:
    # NOTE(review): `validation` receives the *string* 'True', not a bool;
    # confirm this is what loaddata.load_data expects.
    data = loaddata.load_data(runtype, validation='True')

    # Axes per objective; reused by every sub-directory of this runtype,
    # so curves from successive sub-directories are drawn on the same figure.
    pltaxes = {}
    for cdir in sorted(os.listdir(resultsdir)):

        if dirfilters is not None and not any(cdir.startswith(f) for f in dirfilters):
            continue

        basedir = os.path.join(resultsdir, cdir, runtype)
        if not os.path.exists(basedir):
            print(basedir, 'doesnt exist')
            continue
        print(basedir)

        tf.reset_default_graph()
        run_results[runtype] = {}
        with tf.Session(config=tfconfig) as sess:
            baselines[runtype] = loadresults(sess, data, basedir, 'basemodel', store_batch=True)
            for run in range(maxruns):
                # Hidden activations are only kept for the first run (used by the scatter plots).
                res = loadresults(sess, data, basedir, 'run' + str(run), store_batch=(run == 0))
                for k, v in res.items():
                    run_results[runtype].setdefault(k, []).append(v)

        for sqmode in ['sq', 'reg']:
            if ('trn', 'nlIB', sqmode) not in run_results[runtype]:
                continue

            if sqmode not in pltaxes:
                # One row of (training, testing) panels sharing the y axis
                _, pltaxes[sqmode] = plt.subplots(1, 2, figsize=(12, 5), sharey='row')

            plotresults(runtype, baselines[runtype], run_results[runtype], pltaxes[sqmode],
                        methods=['nlIB', 'VIB'], sqmodes=[sqmode])
            if savefig:
                if figdir is None:
                    raise Exception('savefig is True but figdir is None')

                os.makedirs(figdir, exist_ok=True)
                pdfname = '%s/%s-%s-%s.pdf' % (figdir, runtype, cdir, sqmode)
                print('saving %s' % pdfname)
                plt.savefig(pdfname, bbox_inches='tight')

            else:
                plt.suptitle(cdir)
            

In [None]:
# Scatter plots of hidden-layer activations (first two PCA components when
# do_pca is True) for the baseline, VIB and nonlinear-IB models, coloured by target.
# Uses `baselines`, `run_results`, `savefig`, `figdir` and `cdir` from the previous cell.
from sklearn.decomposition import PCA

do_pca     = True
sqmode     = 'sq'
trntstmode = 'tst'
colormaps  = {'MNIST':'tab10','FashionMNIST':'tab10','Housing':'plasma'}

for runtype in ['MNIST','FashionMNIST','Housing']:
    todo = [(baselines[runtype][(trntstmode,'ce','sq')], 'Baseline'),
            (run_results[runtype][(trntstmode,'VIB',sqmode)][0], 'VIB'),
            (run_results[runtype][(trntstmode,'nlIB',sqmode)][0], 'nonlinear IB')]

    plt.figure(figsize=(15,4.5))
    for rndx, (r, method) in enumerate(todo):
        plt.subplot(1,3,rndx+1)

        plt.set_cmap(colormaps[runtype])

        if method == 'Baseline':
            ix = 0
            methodname = 'Cross-entropy only' if runtype != 'Housing' else 'MSE only'
        else:
            methodname = method
            # Use the point on the IB curve whose I(X;M) is closest to log(10) nats.
            ix = np.argmin([(l['Ixt_mc'] - np.log(10))**2 for l in r])

        d = r[ix]
        print(method, 'IXT=%0.4f IYT=%0.4f' % (d['Ixt_mc'], d['Iyt']))
        mx, y = d['hidden_acts'], d['pred_y']

        if do_pca:
            pca = PCA(n_components=2, whiten=True)
            pca.fit(mx)
            X_pca = pca.transform(mx)
            print('perc variance explained:', pca.explained_variance_ratio_)
        else:
            X_pca = mx

        # Colour by the raw target: class indices for classification, continuous
        # values for regression (casting to int would truncate Housing targets).
        plt.scatter(X_pca[:, 0], X_pca[:, 1], c=np.ravel(y), s=.005)

        maxes = np.max(X_pca, axis=0)
        mins  = np.min(X_pca, axis=0)
        means = np.mean(X_pca, axis=0)
        # Symmetric window around the mean that reaches the furthest point on
        # either side (plus 10% padding), so no point is clipped.
        halfwidth = 1.1*np.maximum(maxes - means, means - mins)

        plt.xlim([means[0] - halfwidth[0], means[0] + halfwidth[0]])
        plt.ylim([means[1] - halfwidth[1], means[1] + halfwidth[1]])

        plt.xticks([])
        plt.yticks([])
        plt.title(methodname)

    if savefig:
        os.makedirs(figdir, exist_ok=True)
        pngname = '%s/activations-%s-%s-%s.png' % (figdir, runtype, cdir, sqmode)
        print('saving %s' % pngname)
        plt.savefig(pngname, bbox_inches='tight', dpi=300)
