# Comparing MC-Dropout and Moment Propagation

In [1]:
%load_ext autoreload

In [2]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

In [3]:
# The repository root is taken to be two directory levels above the current
# working directory of the notebook.
BASE_PATH = os.path.join(os.getcwd(), "..", "..")
# Folder added to sys.path further down, and folder the datasets are read from.
MODULES_PATH, DATASET_PATH = (
    os.path.join(BASE_PATH, folder) for folder in ("modules", "datasets")
)

In [4]:
# Make the project modules importable. Skip the append when the path is
# already registered so re-running this cell does not pile up duplicates.
if MODULES_PATH not in sys.path:
    sys.path.append(MODULES_PATH)

In [5]:
from bayesian import McDropout, MomentPropagation
from models import fchollet_cnn, setup_growth
from data import BenchmarkData, DataSetType

In [6]:
# Load the MNIST benchmark from its folder under DATASET_PATH, requesting
# float32 as the dtype.
mnist = BenchmarkData(
    DataSetType.MNIST,
    os.path.join(DATASET_PATH, "mnist"),
    dtype=np.float32,
)

In [7]:
# Fix the shuffle seed so the train/test partition, and every number reported
# further down, is the same after a kernel restart. Without it each run draws
# a different random split.
x_train, x_test, y_train, y_test = train_test_split(
    mnist.inputs, mnist.targets, random_state=42
)

In [8]:
# Prepare the runtime through the project helper (it reports the detected GPUs).
# NOTE(review): presumably enables GPU memory growth - confirm in models.py.
setup_growth()

# Training hyperparameters.
epochs, batch_size = 120, 80
# One output per distinct label value present in the dataset.
num_classes = len(np.unique(mnist.targets))

# Build the CNN, then compile and train it on the training split.
base_model = fchollet_cnn(output=num_classes)
base_model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adadelta",
    metrics=["accuracy"],
)
base_model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs)

1 Physical GPU's,  1 Logical GPU's
Epoch 1/120
Epoch 2/120
Epoch 3/120
Epoch 4/120
Epoch 5/120
Epoch 6/120
Epoch 7/120
Epoch 8/120
Epoch 9/120
Epoch 10/120
Epoch 11/120
Epoch 12/120
Epoch 13/120
Epoch 14/120
Epoch 15/120
Epoch 16/120
Epoch 17/120
Epoch 18/120
Epoch 19/120
Epoch 20/120
Epoch 21/120
Epoch 22/120
Epoch 23/120
Epoch 24/120
Epoch 25/120
Epoch 26/120
Epoch 27/120
Epoch 28/120
Epoch 29/120
Epoch 30/120
Epoch 31/120
Epoch 32/120
Epoch 33/120
Epoch 34/120
Epoch 35/120
Epoch 36/120
Epoch 37/120
Epoch 38/120
Epoch 39/120
Epoch 40/120
Epoch 41/120
Epoch 42/120
Epoch 43/120
Epoch 44/120
Epoch 45/120
Epoch 46/120
Epoch 47/120
Epoch 48/120
Epoch 49/120
Epoch 50/120
Epoch 51/120
Epoch 52/120
Epoch 53/120
Epoch 54/120
Epoch 55/120
Epoch 56/120
Epoch 57/120
Epoch 58/120
Epoch 59/120
Epoch 60/120
Epoch 61/120
Epoch 62/120
Epoch 63/120
Epoch 64/120
Epoch 65/120
Epoch 66/120
Epoch 67/120
Epoch 68/120
Epoch 69/120
Epoch 70/120
Epoch 71/120
Epoch 72/120
Epoch 73/120
Epoch 74/120
Epoch 75/120

Epoch 82/120
Epoch 83/120
Epoch 84/120
Epoch 85/120
Epoch 86/120
Epoch 87/120
Epoch 88/120
Epoch 89/120
Epoch 90/120
Epoch 91/120
Epoch 92/120
Epoch 93/120
Epoch 94/120
Epoch 95/120
Epoch 96/120
Epoch 97/120
Epoch 98/120
Epoch 99/120
Epoch 100/120
Epoch 101/120
Epoch 102/120
Epoch 103/120
Epoch 104/120
Epoch 105/120
Epoch 106/120
Epoch 107/120
Epoch 108/120
Epoch 109/120
Epoch 110/120
Epoch 111/120
Epoch 112/120
Epoch 113/120
Epoch 114/120
Epoch 115/120
Epoch 116/120
Epoch 117/120
Epoch 118/120
Epoch 119/120
Epoch 120/120


<tensorflow.python.keras.callbacks.History at 0x7f42ac115370>

In [9]:
# Score the trained model on the held-out split. The returned list holds the
# loss followed by the compiled metric, i.e. [loss, accuracy].
base_model.evaluate(x_test, y_test)



[0.15875177085399628, 0.9571428298950195]

# Compare Entropy values

In [10]:
# Keep the first 100 held-out samples, together with their labels, as the
# common input for the comparison below.
comp_inputs, comp_targets = x_test[:100], y_test[:100]

In [17]:
# Wrap the trained network for MC-Dropout inference and request 100 forward
# passes per comparison input.
# NOTE(review): presumably each pass samples a fresh dropout mask, so results
# vary between runs unless a seed is fixed - confirm in the bayesian module.
mc_model = McDropout(base_model)
mc_pred = mc_model(comp_inputs, sample_size=100)
# Report the prediction shape next to a legend naming each axis.
print("Pred. shape: {} '(num_datapoints, sample_size, target_size)'".format(mc_pred.shape))

Pred. shape: (100, 100, 10) '(num_datapoints, sample_size, target_size)'


In [32]:
# Average over axis 1, the sample axis according to the shape legend printed
# above, leaving one mean prediction vector per datapoint.
expectation = np.mean(mc_pred, axis=1)
expectation.shape

(100, 10)

In [75]:
# Turn the label vector into a single column. -1 lets NumPy infer the row
# count, so the cell no longer depends on the comparison subset holding
# exactly 100 items.
selector_indices = comp_targets.reshape(-1, 1)
selector_indices.shape

(100, 1)

In [84]:
# Probe: an integer index array applies to axis 0 only, so this picks whole
# rows of `expectation` by label value. The result takes the index shape plus
# the trailing axis; it is not one per-sample entry for the true class.
expectation[selector_indices].shape

(100, 1, 10)

In [86]:
# Probe: taking along axis 1 picks, for EVERY row, the columns listed in
# comp_targets. That yields one column per label for each row rather than a
# single entry per row.
np.take(expectation, comp_targets, axis=1).shape

(100, 100)

In [90]:
# Probe: check the shape of the label array for the comparison subset.
comp_targets.shape

(100,)

In [88]:
# Probe: indexing with the label vector selects rows of `expectation` by label
# value (axis 0), so the result keeps the full trailing axis.
expectation[comp_targets].shape

(100, 10)

In [92]:
# Probe: same expression as an earlier cell, re-run. Row selection along
# axis 0 using the column-shaped label array.
expectation[selector_indices].shape

(100, 1, 10)

In [25]:
# Probe: for the first datapoint, average its predictions over axis 0 and
# return the index of the largest mean value.
np.argmax(np.mean(mc_pred[0], axis=0))

0

In [96]:
# Probe: taking along axis 0 selects rows by label value, which gives the same
# result shape as expectation[comp_targets].
np.take(expectation, comp_targets, axis=0).shape

(100, 10)

In [108]:
# For every comparison sample, pick the mean prediction value assigned to its
# true class. Pairing a row index with each label (integer-array indexing on
# both axes) selects exactly one entry per row, expectation[i, comp_targets[i]],
# without a Python loop.
# The float64 cast matches the dtype of the np.zeros buffer this replaces.
true_preds = expectation[np.arange(len(comp_targets)), comp_targets].astype(np.float64)
# Display the result as the cell output.
true_preds

array([0.75258821, 0.92882562, 0.76276612, 0.00594902, 0.86079055,
       0.99511909, 0.89682841, 0.94389063, 0.92224854, 0.54690146,
       0.88509905, 0.80546761, 0.94288337, 0.99008089, 0.65803999,
       0.93628335, 0.98372519, 0.99736273, 0.85688895, 0.45167312,
       0.54412693, 0.99694693, 0.99149042, 0.97506231, 0.90429068,
       0.96476281, 0.99376607, 0.01525678, 0.92910135, 0.80500817,
       0.9209888 , 0.96540415, 0.89284462, 0.99378961, 0.87577802,
       0.36374378, 0.98172319, 0.92099541, 0.98108208, 0.6535446 ,
       0.95879591, 0.18378532, 0.98516661, 0.93913531, 0.69626701,
       0.98560637, 0.96456993, 0.96768713, 0.92921185, 0.92214602,
       0.9149121 , 0.83855164, 0.95621765, 0.99028254, 0.97710383,
       0.98065734, 0.98156381, 0.79554468, 0.88364583, 0.82989639,
       0.97094488, 0.94502759, 0.77303004, 0.98848152, 0.98822248,
       0.98980784, 0.60438412, 0.98501283, 0.0503156 , 0.93848443,
       0.9610399 , 0.97675759, 0.16644867, 0.6491102 , 0.96987

In [14]:
# Wrap the same trained network with MomentPropagation and run it on the
# comparison inputs that were used for MC-Dropout.
# NOTE(review): the structure of mp_pred is not visible here - confirm in the
# bayesian module before comparing it against mc_pred.
mp_model = MomentPropagation(base_model)
mp_pred = mp_model(comp_inputs)