Keras/CNTK: BatchNormalization layer causes predict() values to be incorrect #1994
Using a BatchNormalization layer when training a Keras model on the CNTK backend causes predict() to assign roughly equal probability to every class.

The problem does not reproduce when running the same script with the TensorFlow backend, or when the BatchNormalization() layer is commented out and the script is rerun on Keras/CNTK.
```python
import keras
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Dropout, Input, BatchNormalization
from keras.optimizers import RMSprop
import numpy as np

batch_size = 128
num_classes = 10
epochs = 1

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape, 'train samples')
print(x_test.shape, 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

main_input = Input(shape=(784,))
hidden_1 = Dense(512, activation='relu')(main_input)
hidden_1 = BatchNormalization()(hidden_1)
hidden_1 = Dropout(0.2)(hidden_1)
hidden_2 = Dense(512, activation='relu')(hidden_1)
hidden_2 = Dropout(0.2)(hidden_2)
main_output = Dense(10, activation='softmax', name="main_out")(hidden_2)

model = Model(inputs=main_input, outputs=main_output)
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))
print(model.predict(x_test[0:1]))
```
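For context on why the backend's inference path matters here: BatchNormalization normalizes with per-batch statistics during training, but with accumulated moving averages at inference. If a backend fails to switch to (or never updates) the moving averages, predict() sees badly scaled activations. A minimal NumPy sketch of the two modes (all names below are illustrative, not Keras internals):

```python
import numpy as np

def batchnorm(x, gamma, beta, moving_mean, moving_var, training, eps=1e-3):
    """Illustrative batch norm: batch stats in training, moving stats at inference."""
    if training:
        mean = x.mean(axis=0)
        var = x.var(axis=0)
    else:
        mean, var = moving_mean, moving_var
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.RandomState(0)
x = rng.normal(loc=5.0, scale=2.0, size=(128, 4))  # activations far from zero mean / unit var

gamma = np.ones(4)
beta = np.zeros(4)
# Properly accumulated moving statistics (close to the true data statistics):
good_mean, good_var = x.mean(axis=0), x.var(axis=0)
# Never-updated statistics, as if the backend skipped the moving-average updates:
stale_mean, stale_var = np.zeros(4), np.ones(4)

y_good = batchnorm(x, gamma, beta, good_mean, good_var, training=False)
y_stale = batchnorm(x, gamma, beta, stale_mean, stale_var, training=False)
print(np.abs(y_good).mean())   # well-normalized, roughly unit scale
print(np.abs(y_stale).mean())  # far off-scale, distorting everything downstream
```

This is only a sketch of the mechanism, not a claim about where the CNTK backend goes wrong.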
Observed (CNTK backend): all classes have near-equal probability.

Expected: one class has probability near 1; the rest are near 0.
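The symptom is easy to check programmatically. A small hypothetical helper (not part of Keras) that flags a prediction indistinguishable from the uniform distribution over the classes:

```python
import numpy as np

def is_near_uniform(probs, tol=0.02):
    """True if every class probability is within tol of 1/num_classes."""
    probs = np.asarray(probs)
    return bool(np.all(np.abs(probs - 1.0 / probs.shape[-1]) < tol))

buggy = np.full(10, 0.1)                 # what the CNTK backend produces here
healthy = np.eye(10)[3] * 0.98 + 0.002   # one class near 1, the rest near 0
print(is_near_uniform(buggy))    # True
print(is_near_uniform(healthy))  # False
```

Applied to `model.predict(x_test[0:1])[0]`, this returns True on the CNTK backend and False on TensorFlow for the script above.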