In [None]:
from __future__ import print_function

import os
import sys
import numpy as np

from keras.optimizers import SGD
from keras.callbacks import CSVLogger, ModelCheckpoint

# Make the parent of the current working directory importable so the project-local
# modules below (`config`, `utils`, `models`) resolve.
# NOTE(review): this is relative to the kernel's cwd, not to the notebook file itself —
# it assumes the notebook is launched from its own directory; confirm.
sys.path.append(os.path.join(os.getcwd(), os.pardir))

# Project-level settings module; later cells read the train/validation/test dataset
# file paths from it.
import config

from utils.dataset.data_generator import DataGenerator
from models.cnn3_with_normalization import cnn_w_normalization

In [None]:
# Training hyperparameters.
lr = 0.01                    # SGD learning rate
n_epochs = 500
batch_size = 32
input_shape = (140, 140, 3)  # presumably (height, width, channels) for 140px RGB input — TODO confirm

# Run identifier used as the prefix for the log and checkpoint file names.
# The learning rate is rendered with '{:f}' (fixed point, 6 decimals), e.g. 0.010000.
name = 'cnn_140_rgb_lr_{:f}'.format(lr)

In [None]:
print('loading model...')
model = cnn_w_normalization(input_shape=input_shape)
model.summary()

# Plain SGD with the configured learning rate; `lr=` is the Keras 1.x keyword
# (this notebook uses the Keras 1.x API throughout, e.g. nb_epoch / samples_per_epoch).
optimizer = SGD(lr=lr)

print('compiling model...')
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
print('done.')

# Writes the per-epoch training/validation metrics to a CSV file named after the run.
csv_logger = CSVLogger('%s_training.log' % name)

# Saves only when the monitored quantity improves (ModelCheckpoint's default monitor
# is 'val_loss'), so this file holds the best model seen so far.
# NOTE: save_weights_only defaults to False, so the full model is saved despite the
# "weights" in the file name.
best_model_checkpointer = ModelCheckpoint(filepath=("./%s_training_weights_best.hdf5" % name), verbose=1,
                                          save_best_only=True)

# Saves the model after every epoch to a SEPARATE file. It must not share the "_best"
# path: it runs after the best checkpointer in the callbacks list and would overwrite
# the best checkpoint with the latest epoch every time.
current_model_checkpointer = ModelCheckpoint(filepath=("./%s_training_weights_current.hdf5" % name), verbose=0)

In [None]:
print('Initializing data generators...')
# One DataGenerator per split, all sharing the configured batch size; the unpacking
# order matches the order of the dataset files listed below.
train_data_gen, validation_data_gen, test_data_gen = (
    DataGenerator(dataset_file=dataset_file, batch_size=batch_size)
    for dataset_file in (config.train_data_file,
                         config.validation_data_file,
                         config.test_data_file)
)
print('done.')

In [None]:
print('Fitting model...')
# Keras 1.x generator API: epoch length and validation size are given in samples.
# Training samples per epoch = full batches x batch size (assumes DataGenerator yields
# n_batches batches of batch_size each — TODO confirm); validation consumes n_samples.
history = model.fit_generator(
    train_data_gen,
    samples_per_epoch=train_data_gen.n_batches * batch_size,
    nb_epoch=n_epochs,
    verbose=1,
    callbacks=[csv_logger, best_model_checkpointer, current_model_checkpointer],
    validation_data=validation_data_gen,
    nb_val_samples=validation_data_gen.n_samples
)
print('done.')

In [None]:
print('Evaluating model...')
# Evaluate on the held-out test split, consuming its full sample count.
n_test_samples = test_data_gen.n_samples
score = model.evaluate_generator(test_data_gen, val_samples=n_test_samples)
print('done.')

# The model was compiled with one metric (accuracy), so `score` is [loss, accuracy].
test_loss, test_accuracy = score[0], score[1]
print('Test score:', test_loss)
print('Test accuracy:', test_accuracy)