## Notebook showing core architectural difference in terms of parameter count between 
### 1. Deep Ensmeble
### 2. Batch Ensmeble &
### 3. Rank-1 BNN



In [6]:

import tensorflow as tf
import tensorflow.keras.layers as tfkl
import edward2 as ed
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go

### Define the ensemble size

In [12]:
ENSEMBLE_SIZE=4

### Define the conventional MLP model

In [48]:
def get_model_conventional(num_hidden=1,hidden_units=50,name="convention_model"):
    # define the model
    x_input = tfkl.Input(shape=[1,],name="input_layer")
    
    dense=x_input
    for i in range(num_hidden):
        dense = tfkl.Dense(hidden_units, tf.nn.relu,name="dense_{}".format(i))(dense)
    out = tfkl.Dense(1,name="output_layer")(dense)

    model = tf.keras.Model(inputs=x_input, outputs=[out],name=name)
    return model

In [49]:
model=get_model_conventional(num_hidden=1,hidden_units=50)
print("num of params",model.count_params())
model.summary()

num of params 151
Model: "convention_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 1)]               0         
                                                                 
 dense_0 (Dense)             (None, 50)                100       
                                                                 
 output_layer (Dense)        (None, 1)                 51        
                                                                 
Total params: 151
Trainable params: 151
Non-trainable params: 0
_________________________________________________________________


### Define the deep ensmeble model

In [50]:
def get_deep_ensemble(model_func,model_hyp,ensemble_size=ENSEMBLE_SIZE):
    models = []
    for i in range(ensemble_size):
        mdl=model_func(num_hidden=model_hyp[0],hidden_units=model_hyp[1],name="deep_ens_mem-{}".format(i))
        models.append(mdl)
    return models

In [61]:
models=get_deep_ensemble(get_model_conventional,(1,50),ensemble_size=ENSEMBLE_SIZE)
model_ensemble_params=0
for i in range(len(models)):
    print("====Member-{}====".format(i))
    num_params=models[i].count_params()
    print("num of params",num_params)
    print(models[i].summary())
    model_ensemble_params+=num_params
print("Toatal number of params in model ensemble: ",model_ensemble_params)
print("Ensemble_n*ENSEMBLE_SIZE: ",num_params*ENSEMBLE_SIZE)

====Member-0====
num of params 151
Model: "deep_ens_mem-0"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 1)]               0         
                                                                 
 dense_0 (Dense)             (None, 50)                100       
                                                                 
 output_layer (Dense)        (None, 1)                 51        
                                                                 
Total params: 151
Trainable params: 151
Non-trainable params: 0
_________________________________________________________________
None
====Member-1====
num of params 151
Model: "deep_ens_mem-1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 1)]               0         
                     

#### In deep ensmeble the difference is in the initilization of the models

In [52]:
w1=models[0].layers[1].weights
w2=models[1].layers[1].weights
print(w1)
print(w2)

[<tf.Variable 'dense_0/kernel:0' shape=(1, 50) dtype=float32, numpy=
array([[ 0.2572733 ,  0.2646002 , -0.08729094, -0.13128406,  0.33941826,
        -0.20077209, -0.13205856,  0.26357552,  0.24014375, -0.11084662,
        -0.27933776,  0.10607699, -0.1456709 ,  0.20059481,  0.05305547,
        -0.18292984,  0.02290425, -0.17155173, -0.05053297,  0.3277506 ,
         0.02543876,  0.25356475,  0.1238243 ,  0.28033963, -0.3055117 ,
        -0.3137288 , -0.21066244, -0.06668898, -0.18757641,  0.16638026,
        -0.32150245, -0.15873754,  0.19092748, -0.13500507,  0.00444245,
         0.23786333, -0.27174127,  0.15575588, -0.20113216,  0.33055732,
         0.0040195 , -0.1520585 , -0.06767324,  0.32948866,  0.3083227 ,
         0.03680754,  0.11554953, -0.22810438, -0.22667843,  0.23155841]],
      dtype=float32)>, <tf.Variable 'dense_0/bias:0' shape=(50,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0.

### Define the Batch Ensemble model
#### The rank-1 factor alpha and gamma are initialized using sign initializer

In [13]:

def make_sign_initializer(random_sign_init):
  if random_sign_init > 0:
    return ed.initializers.RandomSign(random_sign_init)
  else:
    return tf.keras.initializers.RandomNormal(mean=1.0,
                                              stddev=-random_sign_init)

In [55]:
def get_model_batchensemble(num_hidden=1,hidden_units=50,ensemble_size=ENSEMBLE_SIZE,name="batch_ens"):
    # define the model
    x_input = tfkl.Input(shape=[1,],name="input_layer")
    dense = x_input
    for i in range(num_hidden):
        #dense = ed.layers.DenseBatchEnsemble(units=hidden_units,rank=1,ensemble_size=ensemble_size,use_bias=True,)(dense)
                                      #alpha_initializer='ones',gamma_initializer='ones',)(dense)
        
        dense = ed.layers.DenseBatchEnsemble(units=hidden_units,rank=1,ensemble_size=ensemble_size,use_bias=True,
                                            alpha_initializer=make_sign_initializer(-0.5),
                                            gamma_initializer=make_sign_initializer(-0.5),name="dense_batch_{}".format(i))(dense)

    out = ed.layers.DenseBatchEnsemble(units=1,rank=1,ensemble_size=ensemble_size,use_bias=True,name="output_layer")(dense)
                                      #alpha_initializer='ones',gamma_initializer='ones',)(dense)

    model = tf.keras.Model(inputs=x_input, outputs=[out],name=name)
    return model

In [56]:
model_be = get_model_batchensemble(num_hidden=1,hidden_units=50)
print("num of params",model_be.count_params())
model_be.summary()

num of params 712
Model: "batch_ens"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 1)]               0         
                                                                 
 dense_batch_0 (DenseBatchEn  (None, 50)               454       
 semble)                                                         
                                                                 
 output_layer (DenseBatchEns  (None, 1)                258       
 emble)                                                          
                                                                 
Total params: 712
Trainable params: 712
Non-trainable params: 0
_________________________________________________________________


### Define Rank-1 BNN

#### helper functions taken from : https://github.com/google/uncertainty-baselines/blob/8e1284ad1dfc11addcbdb8188f116db7424f0f98/uncertainty_baselines/models/rank1_bnn_utils.py#L30

In [75]:
def _make_sign_initializer(random_sign_init):
  if random_sign_init > 0:
    return ed.initializers.RandomSign(random_sign_init)
  else:
    return tf.keras.initializers.RandomNormal(mean=1.0,
                                              stddev=-random_sign_init)

def make_initializer(initializer, random_sign_init, dropout_rate):
  """Builds initializer with specific mean and/or stddevs."""
  if initializer == 'trainable_deterministic':
    return ed.initializers.TrainableDeterministic(
        loc_initializer=_make_sign_initializer(random_sign_init))
  elif initializer == 'trainable_half_cauchy':
    stddev_init = np.log(np.expm1(np.sqrt(dropout_rate / (1. - dropout_rate))))
    return ed.initializers.TrainableHalfCauchy(
        loc_initializer=_make_sign_initializer(random_sign_init),
        scale_initializer=tf.keras.initializers.Constant(stddev_init),
        scale_constraint='softplus')
  elif initializer == 'trainable_cauchy':
    stddev_init = np.log(np.expm1(np.sqrt(dropout_rate / (1. - dropout_rate))))
    return ed.initializers.TrainableCauchy(
        loc_initializer=_make_sign_initializer(random_sign_init),
        scale_initializer=tf.keras.initializers.Constant(stddev_init),
        scale_constraint='softplus')
  elif initializer == 'trainable_normal':
    stddev_init = np.log(np.expm1(np.sqrt(dropout_rate / (1. - dropout_rate))))
    return ed.initializers.TrainableNormal(
        mean_initializer=_make_sign_initializer(random_sign_init),
        stddev_initializer=tf.keras.initializers.TruncatedNormal(
            mean=stddev_init, stddev=0.1),
        stddev_constraint='softplus')
  elif initializer == 'trainable_log_normal':
    stddev_init = np.log(np.expm1(np.sqrt(dropout_rate / (1. - dropout_rate))))
    return ed.initializers.TrainableLogNormal(
        loc_initializer=_make_sign_initializer(random_sign_init),
        scale_initializer=tf.keras.initializers.TruncatedNormal(
            mean=stddev_init, stddev=0.1),
        scale_constraint='softplus')
  elif initializer == 'trainable_normal_fixed_stddev':
    return ed.initializers.TrainableNormalFixedStddev(
        stddev=tf.sqrt(dropout_rate / (1. - dropout_rate)),
        mean_initializer=_make_sign_initializer(random_sign_init))
  elif initializer == 'trainable_normal_shared_stddev':
    stddev_init = np.log(np.expm1(np.sqrt(dropout_rate / (1. - dropout_rate))))
    return ed.initializers.TrainableNormalSharedStddev(
        mean_initializer=_make_sign_initializer(random_sign_init),
        stddev_initializer=tf.keras.initializers.Constant(stddev_init),
        stddev_constraint='softplus')
  return initializer


def make_regularizer(regularizer, mean, stddev):
  """Builds regularizer with specific mean and/or stddevs."""
  if regularizer == 'normal_kl_divergence':
    return ed.regularizers.NormalKLDivergence(mean=mean, stddev=stddev)
  elif regularizer == 'log_normal_kl_divergence':
    return ed.regularizers.LogNormalKLDivergence(
        loc=tf.math.log(1.), scale=stddev)
  elif regularizer == 'normal_kl_divergence_with_tied_mean':
    return ed.regularizers.NormalKLDivergenceWithTiedMean(stddev=stddev)
  elif regularizer == 'cauchy_kl_divergence':
    return ed.regularizers.CauchyKLDivergence(loc=mean, scale=stddev)
  elif regularizer == 'normal_empirical_bayes_kl_divergence':
    return ed.regularizers.NormalEmpiricalBayesKLDivergence(mean=mean)
  elif regularizer == 'trainable_normal_kl_divergence_stddev':
    return ed.regularizers.TrainableNormalKLDivergenceStdDev(mean=mean)
  return regularizer

In [76]:
kl_annealing_epochs=200 #'Number of epoch over which to anneal the KL term to 1.')
alpha_initializer = 'trainable_normal' #'Initializer name for the alpha parameters.')
gamma_initializer = 'trainable_normal' #'Initializer name for the gamma parameters.')
alpha_regularizer = 'normal_kl_divergence' #'Regularizer name for the alpha parameters.')
gamma_regularizer = 'normal_kl_divergence' #'Regularizer name for the gamma parameters.')
use_additive_perturbation = False # 'Use additive perturbations instead of multiplicative.')
dropout_rate = 1e-3 # 'Dropout rate. Only used if alpha/gamma initializers are, e.g., trainable normal.')
prior_mean = 1. # 'Prior mean.')
prior_stddev = 0.1 #'Prior stddev. Sort of like a prior on dropout rate, where it encourages defaulting/shrinking to this value.')

random_sign_init = 0.5 #'Use random sign init for fast weights.')
fast_weight_lr_multiplier = 1.0 #'fast weights lr multiplier.')
num_eval_samples = 1 #'Number of model predictions to sample per example at eval time.')

In [83]:
def get_model_rank_1_bnn(num_hidden=1,hidden_units=50,ensemble_size=ENSEMBLE_SIZE,name="rank_1_bnn"):
    # define the model
    x_input = tfkl.Input(shape=[1,],name="input_layer")
    dense = x_input
    for i in range(num_hidden):
        #dense = ed.layers.DenseBatchEnsemble(units=hidden_units,rank=1,ensemble_size=ensemble_size,use_bias=True,)(dense)
                                      #alpha_initializer='ones',gamma_initializer='ones',)(dense)
            
         dense = ed.layers.DenseRank1(
              units=hidden_units,
              alpha_initializer=make_initializer(
                  alpha_initializer, random_sign_init, dropout_rate),
              gamma_initializer=make_initializer(
                  gamma_initializer, random_sign_init, dropout_rate),
              kernel_initializer='he_normal',
              activation=None,
              alpha_regularizer=make_regularizer(
                  alpha_regularizer, prior_mean, prior_stddev),
              gamma_regularizer=make_regularizer(
                  gamma_regularizer, prior_mean, prior_stddev),
              use_additive_perturbation=use_additive_perturbation,
              ensemble_size=ensemble_size,name="densem_rank_1_bnn_{}".format(i))(dense)
        
         #dense = ed.layers.DenseBatchEnsemble(units=hidden_units,rank=1,ensemble_size=ensemble_size,use_bias=True,
         #                                    alpha_initializer=make_sign_initializer(-0.5),
         #                                   gamma_initializer=make_sign_initializer(-0.5))(dense)
         dense = tf.keras.layers.Activation('relu',name="dense_{}".format(i))(dense)


    out = ed.layers.DenseRank1(
              units=1,
              alpha_initializer=make_initializer(
                  alpha_initializer, random_sign_init, dropout_rate),
              gamma_initializer=make_initializer(
                  gamma_initializer, random_sign_init, dropout_rate),
              kernel_initializer='he_normal',
              activation=None,
              alpha_regularizer=make_regularizer(
                  alpha_regularizer, prior_mean, prior_stddev),
              gamma_regularizer=make_regularizer(
                  gamma_regularizer, prior_mean, prior_stddev),
              use_additive_perturbation=use_additive_perturbation,
              ensemble_size=ensemble_size,name="output_layer")(dense)

    model = tf.keras.Model(inputs=x_input, outputs=[out],name=name)
    return model

In [84]:
model_r1_bnn = get_model_rank_1_bnn(num_hidden=2)
print("num of params",model_r1_bnn.count_params())
model_r1_bnn.summary()

num of params 4620
Model: "rank_1_bnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 1)]               0         
                                                                 
 densem_rank_1_bnn_0 (DenseR  (None, 50)               658       
 ank1)                                                           
                                                                 
 dense_0 (Activation)        (None, 50)                0         
                                                                 
 densem_rank_1_bnn_1 (DenseR  (None, 50)               3500      
 ank1)                                                           
                                                                 
 dense_1 (Activation)        (None, 50)                0         
                                                                 
 output_layer (DenseRank1)   (None, 1

### Parameters comparison

In [145]:
def plot_parameter_increase(deep_ens, batch_ens, r1_bnn,hidden_units=50,save_plot=False,name="parameter_efficiency"):
    hid_layers=len(deep_ens)
    x=[x+1 for x in list(range(hid_layers))]
    fig = go.Figure()
    # Create and style traces
    fig.add_trace(go.Scatter(x=x, y=deep_ens, name='deep_ens',
                             line=dict(color='firebrick', width=4)))
    fig.add_trace(go.Scatter(x=x, y=batch_ens, name='batch_ens',
                             line=dict(color='green', width=4)))
    fig.add_trace(go.Scatter(x=x, y=r1_bnn, name='r1_bnn',
                             line=dict(color='royalblue', width=4)))
    plot_name='{}_{}'.format(name,hidden_units)
    
    fig.update_layout(
        title=plot_name,
        xaxis_title="num hidden units",
        yaxis_title="parameters increase(Xtimes)",)
    
    fig.show()
    if save_plot:
        fig.write_html('{}.html'.format(plot_name), auto_open=True)

#### hidden_units=50, varying num hidden layers 1 to 10

In [129]:
#hidden_units=50
num_params_single=[]
num_params_deep_ens=[]
num_params_batch_ens=[]
num_params_r1_bnn=[]

for hid in range(1,10):
    model_cl = get_model_conventional(num_hidden=hid,hidden_units=50)
    num_params_single.append(model_cl.count_params())
    num_params_deep_ens.append(model_cl.count_params()*ENSEMBLE_SIZE)
    model_be = get_model_batchensemble(num_hidden=hid,hidden_units=50)
    num_params_batch_ens.append(model_be.count_params())
    model_r1_bnn = get_model_rank_1_bnn(num_hidden=hid,hidden_units=50)
    num_params_r1_bnn.append(model_r1_bnn.count_params())

print("single model: ",num_params_single)
print("deep ensemble: ",num_params_deep_ens)
print("batch ensemble: ",num_params_batch_ens)
print("rank-1 bnn: ",num_params_r1_bnn)

increase_deep_ens=np.divide(num_params_deep_ens,num_params_single)
print("parameter increase for deep ensmeble: ",increase_deep_ens)
increase_batch_ens=np.divide(num_params_batch_ens,num_params_single)
print("parameter increase for batch ensmeble: ",increase_batch_ens)
increase_r1_bnn=np.divide(num_params_r1_bnn,num_params_single)
print("parameter increase for rank-1 bnn: ",increase_r1_bnn)

single model:  [151, 2701, 5251, 7801, 10351, 12901, 15451, 18001, 20551]
deep ensemble:  [604, 10804, 21004, 31204, 41404, 51604, 61804, 72004, 82204]
batch ensemble:  [712, 3812, 6912, 10012, 13112, 16212, 19312, 22412, 25512]
rank-1 bnn:  [1120, 4620, 8120, 11620, 15120, 18620, 22120, 25620, 29120]
parameter increase for deep ensmeble:  [4. 4. 4. 4. 4. 4. 4. 4. 4.]
parameter increase for batch ensmeble:  [4.71523179 1.41132914 1.3163207  1.2834252  1.26673751 1.25664677
 1.24988674 1.24504194 1.24139945]
parameter increase for rank-1 bnn:  [7.41721854 1.7104776  1.54637212 1.48955262 1.46072843 1.44329897
 1.43162255 1.42325426 1.41696268]


In [146]:
plot_parameter_increase(increase_deep_ens,increase_batch_ens,increase_r1_bnn,hidden_units=50)

#### hidden_units=500, varying num hidden layers 1 to 10

In [147]:
#hidden_units=500
num_params_single=[]
num_params_deep_ens=[]
num_params_batch_ens=[]
num_params_r1_bnn=[]

for hid in range(1,10):
    model_cl = get_model_conventional(num_hidden=hid,hidden_units=500)
    num_params_single.append(model_cl.count_params())
    num_params_deep_ens.append(model_cl.count_params()*ENSEMBLE_SIZE)
    model_be = get_model_batchensemble(num_hidden=hid,hidden_units=500)
    num_params_batch_ens.append(model_be.count_params())
    model_r1_bnn = get_model_rank_1_bnn(num_hidden=hid,hidden_units=500)
    num_params_r1_bnn.append(model_r1_bnn.count_params())
    
print("single model: ",num_params_single)
print("deep ensemble: ",num_params_deep_ens)
print("batch ensemble: ",num_params_batch_ens)
print("rank-1 bnn: ",num_params_r1_bnn)

increase_deep_ens=np.divide(num_params_deep_ens,num_params_single)
print("parameter increase for deep ensmeble: ",increase_deep_ens)
increase_batch_ens=np.divide(num_params_batch_ens,num_params_single)
print("parameter increase for batch ensmeble: ",increase_batch_ens)
increase_r1_bnn=np.divide(num_params_r1_bnn,num_params_single)
print("parameter increase for rank-1 bnn: ",increase_r1_bnn)

single model:  [1501, 252001, 502501, 753001, 1003501, 1254001, 1504501, 1755001, 2005501]
deep ensemble:  [6004, 1008004, 2010004, 3012004, 4014004, 5016004, 6018004, 7020004, 8022004]
batch ensemble:  [7012, 263012, 519012, 775012, 1031012, 1287012, 1543012, 1799012, 2055012]
rank-1 bnn:  [11020, 271020, 531020, 791020, 1051020, 1311020, 1571020, 1831020, 2091020]
parameter increase for deep ensmeble:  [4. 4. 4. 4. 4. 4. 4. 4. 4.]
parameter increase for batch ensmeble:  [4.6715523  1.04369427 1.03285765 1.02923104 1.02741502 1.02632454
 1.02559719 1.02507748 1.0246876 ]
parameter increase for rank-1 bnn:  [7.34177215 1.07547192 1.05675412 1.05048997 1.04735322 1.04546966
 1.04421333 1.04331564 1.04264221]


In [148]:
plot_parameter_increase(increase_deep_ens,increase_batch_ens,increase_r1_bnn,hidden_units=50)

#### hidden_units=1000, varying num hidden layers 1 to 10

In [149]:
#hidden_units=1000
num_params_single=[]
num_params_deep_ens=[]
num_params_batch_ens=[]
num_params_r1_bnn=[]

for hid in range(1,10):
    model_cl = get_model_conventional(num_hidden=hid,hidden_units=1000)
    num_params_single.append(model_cl.count_params())
    num_params_deep_ens.append(model_cl.count_params()*ENSEMBLE_SIZE)
    model_be = get_model_batchensemble(num_hidden=hid,hidden_units=1000)
    num_params_batch_ens.append(model_be.count_params())
    model_r1_bnn = get_model_rank_1_bnn(num_hidden=hid,hidden_units=1000)
    num_params_r1_bnn.append(model_r1_bnn.count_params())
    
print("single model: ",num_params_single)
print("deep ensemble: ",num_params_deep_ens)
print("batch ensemble: ",num_params_batch_ens)
print("rank-1 bnn: ",num_params_r1_bnn)

increase_deep_ens=np.divide(num_params_deep_ens,num_params_single)
print("parameter increase for deep ensmeble: ",increase_deep_ens)
increase_batch_ens=np.divide(num_params_batch_ens,num_params_single)
print("parameter increase for batch ensmeble: ",increase_batch_ens)
increase_r1_bnn=np.divide(num_params_r1_bnn,num_params_single)
print("parameter increase for rank-1 bnn: ",increase_r1_bnn)

single model:  [3001, 1004001, 2005001, 3006001, 4007001, 5008001, 6009001, 7010001, 8011001]
deep ensemble:  [12004, 4016004, 8020004, 12024004, 16028004, 20032004, 24036004, 28040004, 32044004]
batch ensemble:  [14012, 1026012, 2038012, 3050012, 4062012, 5074012, 6086012, 7098012, 8110012]
rank-1 bnn:  [22020, 1042020, 2062020, 3082020, 4102020, 5122020, 6142020, 7162020, 8182020]
parameter increase for deep ensmeble:  [4. 4. 4. 4. 4. 4. 4. 4. 4.]
parameter increase for batch ensmeble:  [4.6691103  1.02192328 1.01646433 1.01464105 1.01372872 1.01318111
 1.01281594 1.01255506 1.01235938]
parameter increase for rank-1 bnn:  [7.33755415 1.03786749 1.02843839 1.02528908 1.02371325 1.02276737
 1.02213662 1.02168602 1.02134802]


In [150]:
plot_parameter_increase(increase_deep_ens,increase_batch_ens,increase_r1_bnn,hidden_units=50)