In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
import numpy as np

import keras
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet50 import preprocess_input

from resnet import get_resnet

In [3]:
data_dir = '/home/ubuntu/data/'

In [4]:
data_generator = ImageDataGenerator(
    data_format='channels_last',
    preprocessing_function=lambda x: preprocess_input(np.expand_dims(x, 0))
)

train_generator = data_generator.flow_from_directory(
    data_dir + 'train', 
    target_size=(224, 224),
    batch_size=64
)

val_generator = data_generator.flow_from_directory(
    data_dir + 'val', 
    target_size=(224, 224),
    batch_size=64
)

Found 25600 images belonging to 256 classes.
Found 5120 images belonging to 256 classes.


In [22]:
model = get_resnet()

for layer in model.layers[:-2]:
    layer.trainable = False
    
model.layers[-2].kernel_regularizer = keras.regularizers.l2(1e-2)

In [24]:
model.compile(
    optimizer=optimizers.Adam(lr=1e-3), 
    loss='categorical_crossentropy', metrics=['accuracy', 'top_k_categorical_accuracy']
)

In [25]:
model.fit_generator(
    train_generator, 
    steps_per_epoch=150, epochs=10, verbose=2,
    validation_data=val_generator, validation_steps=16, 
    max_queue_size=10, workers=1, use_multiprocessing=False
)

Epoch 1/10
35s - loss: 2.3956 - acc: 0.5335 - top_k_categorical_accuracy: 0.7130 - val_loss: 1.1560 - val_acc: 0.7256 - val_top_k_categorical_accuracy: 0.9053
Epoch 2/10
34s - loss: 0.6579 - acc: 0.8544 - top_k_categorical_accuracy: 0.9584 - val_loss: 0.9891 - val_acc: 0.7539 - val_top_k_categorical_accuracy: 0.9248
Epoch 3/10
34s - loss: 0.4937 - acc: 0.8850 - top_k_categorical_accuracy: 0.9697 - val_loss: 0.8352 - val_acc: 0.7871 - val_top_k_categorical_accuracy: 0.9404
Epoch 4/10
34s - loss: 0.3455 - acc: 0.9172 - top_k_categorical_accuracy: 0.9837 - val_loss: 0.7942 - val_acc: 0.7979 - val_top_k_categorical_accuracy: 0.9316
Epoch 5/10
34s - loss: 0.1665 - acc: 0.9722 - top_k_categorical_accuracy: 0.9959 - val_loss: 0.8457 - val_acc: 0.7979 - val_top_k_categorical_accuracy: 0.9395
Epoch 6/10
34s - loss: 0.1580 - acc: 0.9684 - top_k_categorical_accuracy: 0.9954 - val_loss: 0.8162 - val_acc: 0.8105 - val_top_k_categorical_accuracy: 0.9365
Epoch 7/10
34s - loss: 0.1079 - acc: 0.9814 - 

<keras.callbacks.History at 0x7f857bd87278>

In [26]:
model.save_weights('resnet_weights.hdf5')