Skip to content

Commit

Permalink
Add early stop after 1st epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
acflorea committed Feb 26, 2018
1 parent 6e155e5 commit d44c90c
Showing 1 changed file with 50 additions and 15 deletions.
65 changes: 50 additions & 15 deletions cifar10-cnn.py
Expand Up @@ -37,7 +37,7 @@ def main(argumentList):

batch_size = int(getValue(argumentsDict, '-b', '--batch_size', 32))

epochs = int(argumentsDict.get('-e', argumentsDict.get('--epochs', 100)))
epochs = int(argumentsDict.get('-e', argumentsDict.get('--epochs', 10)))

data_augmentation = bool(argumentsDict.get('-a', argumentsDict.get('--augmentation', True)))

Expand Down Expand Up @@ -71,6 +71,9 @@ def main(argumentList):
conv_map = conv_map.split(',')
full_map = full_map.split(',')

# If accuracy after 1t epoch is below this limit then break the training
acc_break_limit = -0.15

# The data, shuffled and split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
Expand Down Expand Up @@ -132,9 +135,24 @@ def main(argumentList):
print('Not using data augmentation.')
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
epochs=1,
validation_data=(x_test, y_test),
shuffle=True)

# Score trained model.
scores = model.evaluate(x_test, y_test)

if scores[1] > acc_break_limit:
# If the model looks promising....
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test, y_test),
shuffle=True,
initial_epoch=1)
else:
print('[earlystop] Training stopped after 1st epoch!')

else:
print('Using real-time data augmentation.')
# This will do preprocessing and realtime data augmentation:
Expand All @@ -150,16 +168,35 @@ def main(argumentList):
horizontal_flip=True, # randomly flip images
vertical_flip=False) # randomly flip images

# Compute quantities required for feature-wise normalization
# (std, mean, and principal components if ZCA whitening is applied).
datagen.fit(x_train)
# Compute quantities required for feature-wise normalization
# (std, mean, and principal components if ZCA whitening is applied).
datagen.fit(x_train)

# fit for 1st epoch
model.fit_generator(datagen.flow(x_train, y_train,
batch_size=batch_size),
epochs=1,
validation_data=(x_test, y_test),
workers=4)

# Fit the model on the batches generated by datagen.flow().
model.fit_generator(datagen.flow(x_train, y_train,
batch_size=batch_size),
epochs=epochs,
validation_data=(x_test, y_test),
workers=4)
# Score trained model.
scores = model.evaluate(x_test, y_test)

if scores[1] > acc_break_limit:
# If the model looks promising....
# Fit the model on the batches generated by datagen.flow().
model.fit_generator(datagen.flow(x_train, y_train,
batch_size=batch_size),
epochs=epochs,
validation_data=(x_test, y_test),
workers=4,
initial_epoch=1)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)

else:
print('[earlystop] Training stopped after 1st epoch!')

# Save model and weights
if not os.path.isdir(save_dir):
Expand All @@ -168,10 +205,8 @@ def main(argumentList):
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
print('[results] Test accuracy:', scores[1])
print('[results] Test loss:', scores[0])

sys.stdout.write(str(scores[1]))
sys.stdout.flush()
Expand Down

0 comments on commit d44c90c

Please sign in to comment.