In [None]:
from __future__ import print_function
%matplotlib inline
import matplotlib.pyplot as plt

import keras
import keras.backend as K

from collections import OrderedDict

import numpy as np
import scipy.misc
import scipy.optimize
import scipy.stats

import utils
# Seed numpy's global RNG before anything stochastic runs: Reporter draws its
# per-epoch minibatch permutations with np.random.shuffle. (Whether Keras
# weight initialisation also draws from this RNG depends on the backend --
# TODO confirm.)
SEED = 0
np.random.seed(SEED)

# Train / test splits; later cells use their .X, .Y and .nb_classes fields
# (project-local helper, see utils.get_mnist).
trn, tst = utils.get_mnist()

# Positions in model.layers whose activity and gradients are tracked.
PLOT_LAYERS    = [2,3]
NUM_EPOCHS     = 10000
# Passed as num_of_bins to simplebinmi.bin_calc_information.
NUMBER_OF_BINS = 3

# Minibatch size for training and for the gradient statistics in Reporter.
SGD_BATCHSIZE = 128

# Hidden-layer activation; 'relu' is the alternative explored.
#ACTIVATION = 'relu'
ACTIVATION = 'tanh'

In [None]:
import simplebinmi, kde

def wrapKfunc(f, data):
    """Turn a compiled backend function into a scalar objective of log-variance.

    The returned callable maps ``logvar`` to the first element of the first
    output of ``f([data, 1, exp(logvar)])``. For the functions built in
    Reporter, the middle ``1`` feeds the learning-phase input and the last
    entry feeds the kernel-variance placeholder. The result is suitable for
    ``scipy.optimize.minimize_scalar``.
    """
    def callf(logvar):
        variance = np.exp(logvar)
        outputs = f([data, 1, variance])
        first_output = outputs[0]
        return first_output.flat[0]
    return callf

class Reporter(keras.callbacks.Callback):
    """Keras callback that logs loss, gradient statistics and entropy estimates.

    On every epoch for which ``whentodo(epoch)`` is true it records, for each
    layer index in the notebook-global ``PLOT_LAYERS`` (log keys use the
    position ``lndx`` within ``PLOT_LAYERS``, not the layer id):

    * ``{trn,tst}_loss`` -- total loss on the full split;
    * ``*_weightsnorm``, ``*_gradmean``, ``*_gradstd`` -- norm of the layer
      kernel, and norms of the elementwise mean / std of the kernel gradients
      gathered on full-size minibatches during the epoch;
    * ``*_h_loo``, ``*_logvar``, ``*_h_upper``, ``*_h_lower`` -- KDE entropy
      estimates of the layer output, with the kernel log-variance chosen by
      minimising the leave-one-out estimate;
    * ``*_h_bin``, ``*_h_binstd`` -- ``simplebinmi`` estimates on the raw and
      z-scored (scaled by 0.5) layer output.

    Values are printed, written into Keras' ``logs`` dict, and kept in
    ``self.saved_logs[epoch]``. Also reads the notebook globals ``trn``,
    ``tst``, ``SGD_BATCHSIZE`` and ``NUMBER_OF_BINS``.
    """

    def __init__(self, whentodo=None, *kargs, **kwargs):
        """
        Args:
            whentodo: callable ``epoch -> bool`` selecting the epochs to
                report on; ``None`` means every epoch.
        """
        super(Reporter, self).__init__(*kargs, **kwargs)
        if whentodo is None:
            whentodo = lambda epoch: True

        self.whentodo = whentodo
        # Initialised here so on_batch_begin is safe before any epoch starts.
        self._log_gradients = False
        self.saved_logs = {}

    def on_train_begin(self, logs=None):
        """Compile the backend functions used for logging (once per fit)."""
        # Use the model Keras attached to this callback, not a notebook global.
        model = self.model
        self.layerfuncs = []
        self.layerweights = []

        # Scalar placeholder for the KDE kernel variance, optimised per epoch.
        var = K.placeholder(ndim=0)
        inputs = model.inputs + [K.learning_phase(),] + [var,]
        for lndx in PLOT_LAYERS:
            l = model.layers[lndx]
            # Static output width as a plain int (backend-independent; a TF
            # Dimension object cannot be multiplied by a numpy float).
            dims = K.int_shape(l.output)[1]
            cfuncs = {}
            cfuncs['loo']    = K.function(inputs, [kde.kde_entropy_from_dists_loo(kde.Kget_dists(l.output),l.output,var)])
            cfuncs['upper']  = K.function(inputs, [kde.entropy_estimator(l.output,var)])
            # 'lower' estimate: estimator at 4x the variance plus (d/2)*log(1/4).
            cfuncs['lower']  = K.function(inputs, [kde.entropy_estimator(l.output,4*var) + np.log(0.25)*dims/2.])
            cfuncs['output'] = K.function(model.inputs + [K.learning_phase(),], [l.output])
            self.layerfuncs.append(cfuncs)
            self.layerweights.append(l.kernel)

        self.saved_logs = {}
        # Positional inputs of get_gradients / get_loss:
        # [X, sample_weights, Y, learning_phase].
        input_tensors = [model.inputs[0],
                         model.sample_weights[0],
                         model.targets[0],
                         K.learning_phase(),
        ]
        self.get_gradients = K.function(inputs=input_tensors, outputs=model.optimizer.get_gradients(model.total_loss, self.layerweights))
        # Outputs wrapped in a list: some backends (e.g. TF) require a list.
        self.get_loss = K.function(inputs=input_tensors, outputs=[model.total_loss])

    def on_epoch_begin(self, epoch, logs=None):
        """Decide whether this epoch is logged and, if so, reset the per-epoch buffers."""
        self._log_gradients = bool(self.whentodo(epoch))
        if not self._log_gradients:
            return
        self._batch_todo_ixs = {}
        self._batch_gradients = {}
        for cdata, cdataname in ((trn,'trn'), (tst, 'tst')):
            self._batch_gradients[cdataname] = [ [] for _ in PLOT_LAYERS ]
            # Fresh random permutation of each split for this epoch.
            ixs = list(range(len(cdata.X)))
            np.random.shuffle(ixs)
            self._batch_todo_ixs[cdataname] = ixs

    def on_batch_begin(self, batch, logs=None):
        """Record kernel gradients on the next unseen full minibatch of trn and tst."""
        if not self._log_gradients:
            return

        for cdata, cdataname, istrain in ((trn,'trn', 1), (tst, 'tst', 0)):
            cur_ixs = self._batch_todo_ixs[cdataname][:SGD_BATCHSIZE]
            # Stop sampling a split once less than a full batch remains unseen.
            if len(cur_ixs) < SGD_BATCHSIZE:
                continue
            inputs = [cdata.X[cur_ixs,:], [1,]*len(cur_ixs), cdata.Y[cur_ixs,:], istrain]
            for lndx, g in enumerate(self.get_gradients(inputs)):
                # Flatten to 1-D so per-batch gradients can be stacked later.
                self._batch_gradients[cdataname][lndx].append(np.ravel(g))

            # Advance the indexing
            self._batch_todo_ixs[cdataname] = self._batch_todo_ixs[cdataname][SGD_BATCHSIZE:]

    @staticmethod
    def _grad_stats(grads):
        """Return (||mean||, ||std||) over a list of 1-D gradient vectors.

        Returns ``(nan, nan)`` when no minibatch was recorded (e.g. a split
        smaller than one batch) instead of letting ``np.stack`` raise.
        """
        if not grads:
            return np.nan, np.nan
        stackedgrads = np.stack(grads, axis=1)  # (n_params, n_batches)
        return (np.linalg.norm(stackedgrads.mean(axis=1)),
                np.linalg.norm(stackedgrads.std(axis=1)))

    def _kde_entropies(self, cfuncs, data):
        """Return (h_loo, logvar, h_upper, h_lower) for one layer on ``data``.

        The kernel log-variance is found by minimising the leave-one-out
        estimate with Brent's method; the upper / lower estimates are then
        evaluated at that same log-variance.
        """
        r = scipy.optimize.minimize_scalar(wrapKfunc(cfuncs['loo'], data), method='brent')
        h_upper = wrapKfunc(cfuncs['upper'], data)(r.x)
        h_lower = wrapKfunc(cfuncs['lower'], data)(r.x)
        return r.fun, r.x, h_upper, h_lower

    def on_epoch_end(self, epoch, logs=None):
        """Compute, print, and store all statistics for a selected epoch."""
        if not self.whentodo(epoch):
            return
        if logs is None:
            logs = {}

        model = self.model
        # Batch-level gradient collection for this epoch is over.
        self._log_gradients = False

        l = OrderedDict()

        # Get overall performance
        for cdata, cdataname, istrain in ((trn,'trn',1), (tst, 'tst',0)):
            loss = self.get_loss([cdata.X, [1,]*len(cdata.X), cdata.Y, istrain])[0]
            l['%s_loss'%cdataname] = np.asarray(loss).flat[0]

        # Based on https://github.com/ravidziv/IDNNs/blob/1c4926f641d4306af7ae37325358be19e8f4d276/idnns/plots/plot_gradients.py
        for cdataname in ('trn', 'tst'):
            for lndx, layerid in enumerate(PLOT_LAYERS):
                weights_norm = np.linalg.norm(K.get_value(model.layers[layerid].kernel))
                gradmean, gradstd = self._grad_stats(self._batch_gradients[cdataname][lndx])
                l['%s_layer_%d_weightsnorm' % (cdataname, lndx)] = weights_norm
                l['%s_layer_%d_gradmean' % (cdataname, lndx)] = gradmean
                l['%s_layer_%d_gradstd' % (cdataname, lndx)]  = gradstd

        for lndx, cfuncs in enumerate(self.layerfuncs):
            # KDE estimates on a subsample: every 20th train / 10th test point.
            for cdataname, data in (('trn', trn.X[::20]), ('tst', tst.X[::10])):
                h_loo, logvar, h_upper, h_lower = self._kde_entropies(cfuncs, data)
                l['%s_layer_%d_h_loo'   % (cdataname, lndx)] = h_loo
                l['%s_layer_%d_logvar'  % (cdataname, lndx)] = logvar
                l['%s_layer_%d_h_upper' % (cdataname, lndx)] = h_upper
                l['%s_layer_%d_h_lower' % (cdataname, lndx)] = h_lower

            # Binning estimates: every 10th train point, all test points,
            # with the layer evaluated in learning phase 0.
            trndata = trn.X[::10]
            tstdata = tst.X
            trnlayeroutput = cfuncs['output']([trndata, 0])[0]
            tstlayeroutput = cfuncs['output']([tstdata, 0])[0]
            l['trn_layer_%d_h_bin'%lndx] = simplebinmi.bin_calc_information(trndata, trnlayeroutput, num_of_bins=NUMBER_OF_BINS)
            l['tst_layer_%d_h_bin'%lndx] = simplebinmi.bin_calc_information(tstdata, tstlayeroutput, num_of_bins=NUMBER_OF_BINS)
            l['trn_layer_%d_h_binstd'%lndx] = simplebinmi.bin_calc_information(trndata, 0.5*scipy.stats.zscore(trnlayeroutput), num_of_bins=NUMBER_OF_BINS)
            l['tst_layer_%d_h_binstd'%lndx] = simplebinmi.bin_calc_information(tstdata, 0.5*scipy.stats.zscore(tstlayeroutput), num_of_bins=NUMBER_OF_BINS)

        for k,v in l.items():
            print(k,"=",v)
            logs[k] = v

        self.saved_logs[epoch] = l.copy()
        
            
# Fully connected classifier: input -> 1024 -> 5 -> 5 -> softmax.
# NOTE(review): presumably model.layers[0] is the InputLayer, so
# PLOT_LAYERS = [2, 3] selects the two 5-unit layers -- confirm.
input_layer = keras.layers.Input(shape=(trn.X.shape[1],))

hidden_output = keras.layers.Dense(units=1024, activation=ACTIVATION)(input_layer)
hidden_output = keras.layers.Dense(units=5, activation=ACTIVATION)(hidden_output)
hidden_output = keras.layers.Dense(units=5, activation=ACTIVATION)(hidden_output)

outputs = keras.layers.Dense(units=trn.nb_classes, activation='softmax')(hidden_output)
model = keras.models.Model(inputs=input_layer, outputs=outputs)

optimizer = keras.optimizers.SGD(lr=0.001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

def do_report(epoch):
    """Return True if Reporter should log statistics at ``epoch``.

    Every epoch below 20, every 2nd epoch below 100, every 20th afterwards.
    """
    if epoch >= 100:
        return epoch % 20 == 0
    if epoch >= 20:
        return epoch % 2 == 0
    return True
    
# Callback that logs statistics on the epochs selected by do_report.
reporter = Reporter(do_report)

In [None]:
# Train the network; `reporter` records statistics on selected epochs.
r = model.fit(
    x=trn.X,
    y=trn.Y,
    batch_size=SGD_BATCHSIZE,
    epochs=NUM_EPOCHS,
    validation_data=(tst.X, tst.Y),
    callbacks=[reporter],
    verbose=2,
)

In [None]:
# Top row: loss per split. One further row per tracked layer: upper (KL),
# lower (BD) and binning-based entropy estimates of the layer activity.
# (h_loo is also recorded by Reporter but not plotted here.)
import matplotlib.gridspec as gridspec

epochs = sorted(reporter.saved_logs.keys())

n_rows = len(PLOT_LAYERS) + 1
plt.figure(figsize=(15, 5*n_rows))
gs = gridspec.GridSpec(n_rows, 2)

for colndx, t in enumerate(['trn','tst']):
    plt.subplot(gs[0,colndx])
    plt.plot(epochs, [reporter.saved_logs[epoch][t+'_loss'] for epoch in epochs])
    plt.ylim([0, plt.ylim()[1]])
    plt.title('%s loss' % t)
    plt.xlabel('Epochs')
    plt.ylabel('Cross entropy loss')

last_row = len(PLOT_LAYERS) - 1
for rowndx, lndx in enumerate(PLOT_LAYERS):
    for colndx, t in enumerate(['trn','tst']):
        plt.subplot(gs[1+rowndx,colndx])

        # Saved keys are indexed by position in PLOT_LAYERS (rowndx), not layer id.
        plt.plot(epochs, [reporter.saved_logs[epoch][t+'_layer_%d_h_upper' % rowndx] for epoch in epochs], 'r:', label="$H_{KL}$")
        plt.plot(epochs, [reporter.saved_logs[epoch][t+'_layer_%d_h_lower' % rowndx] for epoch in epochs], 'g:', label="$H_{BD}$")
        plt.plot(epochs, [reporter.saved_logs[epoch][t+'_layer_%d_h_bin'   % rowndx] for epoch in epochs], 'b', label="$H_{bin}$")
        plt.plot(epochs, [reporter.saved_logs[epoch][t+'_layer_%d_h_binstd'% rowndx] for epoch in epochs], 'k', label="$H_{binstd}$")

        plt.title('%s layer %d'%(t,lndx))
        plt.xlabel('Epochs')
        plt.ylabel('Layer Activity Entropy')

        # One legend on the bottom-right panel, however many layers are tracked.
        if rowndx == last_row and t == 'tst':
            plt.legend(loc='lower right')
plt.tight_layout()

In [None]:
# Per-layer gradient statistics on the training split (log-log axes): norm
# of the mean minibatch gradient, norm of its std, their ratio (SNR), and ||W||.
import seaborn as sns
import matplotlib.gridspec as gridspec  # do not rely on an earlier cell's import
sns.set_style("white")

plt.figure(figsize=(8,4))
gs = gridspec.GridSpec(1,len(PLOT_LAYERS))
epochs = sorted(reporter.saved_logs.keys())
for lndx, layerid in enumerate(PLOT_LAYERS):
    plt.subplot(gs[0,lndx])
    means = np.array([reporter.saved_logs[epoch]['trn_layer_%d_gradmean' % lndx] for epoch in epochs])
    stds  = np.array([reporter.saved_logs[epoch]['trn_layer_%d_gradstd' % lndx] for epoch in epochs])
    # The weight norm is computed from the kernel alone (same for trn/tst);
    # read the 'trn' entry explicitly instead of a loop variable leaked
    # from another cell.
    wnorms = [reporter.saved_logs[epoch]['trn_layer_%d_weightsnorm' % lndx] for epoch in epochs]
    plt.plot(epochs, means, 'b', label="Mean")
    plt.plot(epochs, stds, 'orange', label="Std")
    plt.plot(epochs, means/stds, 'red', label="SNR")
    plt.plot(epochs, wnorms, 'g', label="||W||")

    plt.title('Layer %d'%layerid)
    plt.xlabel('Epochs')
    ax = plt.gca()
    # Clip non-positive values (e.g. epoch 0). The keyword was renamed from
    # nonposx/nonposy to nonpositive in Matplotlib 3.3 and the old names removed later.
    try:
        ax.set_xscale("log", nonpositive='clip')
        ax.set_yscale("log", nonpositive='clip')
    except (TypeError, ValueError):
        ax.set_xscale("log", nonposx='clip')
        ax.set_yscale("log", nonposy='clip')

plt.legend(loc='lower left')
plt.tight_layout()

#plt.savefig('run_output_%s.pdf'%ACTIVATION)

In [None]:

# Entropy estimates against loss, coloured by epoch; one figure per split.
# Circles: upper (KL) estimate; triangles: lower (BD) estimate.
epochs = sorted(reporter.saved_logs.keys())
# plt.get_cmap works across Matplotlib versions (plt.cm.get_cmap was removed in 3.9).
cm = plt.get_cmap('inferno')

for t in ['trn','tst']:
    loss = [reporter.saved_logs[epoch][t+'_loss'] for epoch in epochs]
    plt.figure()
    for lndx, layerid in enumerate(PLOT_LAYERS):
        upperh = [reporter.saved_logs[epoch][t+'_layer_%d_h_upper' % lndx] for epoch in epochs]
        lowerh = [reporter.saved_logs[epoch][t+'_layer_%d_h_lower' % lndx] for epoch in epochs]
        sc = plt.scatter(loss, upperh, c=epochs, cmap=cm, edgecolor='none', marker='o',
                         label="$H_{KL}$ layer %d" % layerid)
        plt.scatter(loss, lowerh, c=epochs, cmap=cm, edgecolor='none', marker='^',
                    label="$H_{BD}$ layer %d" % layerid)
    plt.colorbar(sc, label='Epoch')
    plt.title(t)
    plt.xlabel('Cross-entropy loss')
    plt.ylabel('H(hidden layer)')
    plt.legend(loc='best')

In [None]:
# Norms of the mean and std of the training minibatch gradients per logged
# epoch, one pair of lines per tracked layer. Reporter stores these
# per epoch in saved_logs; it never creates a `saved_batch_logs` attribute.
# Successive plt.plot calls overlay by default (plt.hold was removed).
sortedk = sorted(reporter.saved_logs.keys())
plt.figure()
for lndx, layerid in enumerate(PLOT_LAYERS):
    meangrads = np.array([reporter.saved_logs[k]['trn_layer_%d_gradmean' % lndx] for k in sortedk])
    stdgrads  = np.array([reporter.saved_logs[k]['trn_layer_%d_gradstd' % lndx] for k in sortedk])
    plt.plot(sortedk, meangrads, label='m (layer %d)' % layerid)
    plt.plot(sortedk, stdgrads, label='std (layer %d)' % layerid)
plt.xlabel('Epochs')
plt.ylabel('Gradient norm')
plt.legend()
