In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

import keras
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import preprocess_input

from resnet import get_resnet

Using TensorFlow backend.


In [3]:
data_dir = '/home/ubuntu/data/'

In [4]:
data_generator = ImageDataGenerator(
    data_format='channels_last',
    preprocessing_function=lambda x: preprocess_input(np.expand_dims(x, 0))
)

train_generator = data_generator.flow_from_directory(
    data_dir + 'train', 
    target_size=(224, 224),
    batch_size=64
)

val_generator = data_generator.flow_from_directory(
    data_dir + 'val', 
    target_size=(224, 224),
    batch_size=64
)

Found 25600 images belonging to 256 classes.
Found 5120 images belonging to 256 classes.


In [5]:
model = get_resnet()

for layer in model.layers[:-14]:
    layer.trainable = False
    
model.layers[-2].kernel_regularizer = keras.regularizers.l2(1e-3)
model.layers[-8].kernel_regularizer = keras.regularizers.l2(1e-3)
model.layers[-11].kernel_regularizer = keras.regularizers.l2(1e-3)
model.layers[-14].kernel_regularizer = keras.regularizers.l2(1e-3)

In [6]:
model.compile(
    optimizer=optimizers.Adam(lr=1e-4), 
    loss='categorical_crossentropy', metrics=['accuracy', 'top_k_categorical_accuracy']
)

In [7]:
model.fit_generator(
    train_generator, 
    steps_per_epoch=150, epochs=7, verbose=1,
    validation_data=val_generator, validation_steps=24, 
    max_queue_size=10, workers=1, use_multiprocessing=False
)

Epoch 1/7
Epoch 2/7
Epoch 3/7
Epoch 4/7
Epoch 5/7
Epoch 6/7
Epoch 7/7


<keras.callbacks.History at 0x7fd35c0570b8>

In [8]:
val_generator_no_shuffle = data_generator.flow_from_directory(
    data_dir + 'val', 
    target_size=(224, 224),
    batch_size=64, shuffle=False
)

Found 5120 images belonging to 256 classes.


In [9]:
model.evaluate_generator(val_generator_no_shuffle, 80)

[0.84789355546236034, 0.80488281250000004, 0.93398437499999998]

In [10]:
model.save_weights('resnet_weights.hdf5')