In [166]:
import os

import keras

from classification_models.keras import Classifiers
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import GlobalAveragePooling2D, Dense
from keras.models import Model
from keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler
from keras.metrics import top_k_categorical_accuracy
from keras.optimizers import SGD

### Configurações de treino

In [167]:
def tl_scheduler(epoch):
    """Learning rate to use for the given epoch index.

    Called by Keras' LearningRateScheduler with the epoch index (0-based).
    Piecewise-constant schedule:
      epochs  0-5  -> 0.005
      epochs  6-10 -> 0.001
      epochs 11-20 -> 0.005   (rate is raised again)
      epochs 21-27 -> 0.001
      epochs 28-35 -> 0.0005
      epochs 36+   -> 0.0001
    NOTE(review): the jump back to 0.005 at epoch 11 looks like a deliberate
    warm restart — confirm it is not a typo.
    """
    # (exclusive upper bound on epoch, learning rate) pairs, checked in order.
    boundaries = (
        (6, 0.005),
        (11, 0.001),
        (21, 0.005),
        (28, 0.001),
        (36, 0.0005),
    )
    for upper_bound, rate in boundaries:
        if epoch < upper_bound:
            return rate
    # Past the last boundary: final, smallest rate.
    return 0.0001


In [168]:
# Hyper-parameters read by the data generators, optimizer and training loop below.
training_config = dict(
    batch_size=128,          # images per batch (train and validation generators)
    target_size=(224, 224),  # (height, width) every image is resized to
    epochs=1,                # value passed as fit_generator(epochs=...)
    lr=0.005,                # initial SGD learning rate
    decay=0,                 # learning-rate decay for SGD
    seed=42,                 # seed for the training generator's shuffling/augmentation
)

### Carregando Dataset

In [169]:
# Artifacts of this run go under <experiments_path>/<experiment>/ (models/ and history/).
experiments_path, experiment = (
    '/data/alberto/iWildCam2020/experiments',
    'resnet18_iwild205_21042020',
)

In [170]:
# Root of the resized dataset. Each split folder holds one sub-folder per class,
# which is the layout flow_from_directory expects.
resized_root = '/data/alberto/iWildCam2020/resized'
train_dir = resized_root + '/train_resized'
validation_dir = resized_root + '/validation_resized'

In [171]:
# Training images are augmented on the fly (random rotation, zoom, shifts, shear,
# horizontal flips). Validation images are fed without augmentation.
# NOTE(review): neither generator sets rescale/preprocessing_function, so raw pixel
# values reach the network — confirm this matches what the pretrained backbone expects.
train_datagen = ImageDataGenerator(rotation_range=20,
                                   zoom_range=0.15,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   shear_range=0.15,
                                   horizontal_flip=True,
                                   fill_mode="nearest")

validation_datagen = ImageDataGenerator()

# Shuffled, seeded batches of (images, one-hot labels) read from the class
# sub-folders of train_dir.
train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=training_config['target_size'],
        batch_size=training_config['batch_size'],
        class_mode='categorical',
        shuffle=True,
        seed=training_config['seed'])

# The validation labels must use exactly the training class -> index mapping, so the
# class list is taken from the training generator, ordered by training index.
# shuffle=False: flow_from_directory shuffles by default, which would change which
# images are scored whenever fewer than all batches are evaluated, making
# val_accuracy (monitored by the checkpoints) vary from epoch to epoch.
validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=training_config['target_size'],
        batch_size=training_config['batch_size'],
        class_mode='categorical',
        classes=sorted(train_generator.class_indices,
                       key=train_generator.class_indices.get),
        shuffle=False)

Found 203689 images belonging to 203 classes.
Found 56736 images belonging to 203 classes.


### Callbacks

In [172]:
# Per-epoch snapshots: ModelCheckpoint saves on every epoch here (save_best_only is
# left at its default False), and the file name embeds the epoch and val_accuracy.
models_dir = os.path.join(experiments_path, experiment, 'models')
os.makedirs(models_dir, exist_ok=True)
model_path = os.path.join(models_dir, 'weights.{epoch:02d}-{val_accuracy:.2f}.hdf5')

checkpoint = ModelCheckpoint(
    filepath=model_path,
    monitor='val_accuracy'
)

# Single "best so far" file, rewritten only when val_accuracy improves. mode='max' is
# stated explicitly instead of relying on 'auto' guessing the direction from the name.
checkpoint_best = ModelCheckpoint(
    filepath=os.path.join(models_dir, 'best.hdf5'),
    save_best_only=True,
    monitor='val_accuracy',
    mode='max'
)

# Per-epoch metrics in CSV form; append=True keeps earlier rows when a run is resumed.
csv_path = os.path.join(experiments_path, experiment, 'history', 'training.log')
os.makedirs(os.path.dirname(csv_path), exist_ok=True)

csv_logger = CSVLogger(csv_path, append=True)

# Sets the optimizer's learning rate each epoch from tl_scheduler(epoch).
lr_scheduler = LearningRateScheduler(tl_scheduler)

In [173]:
# Callbacks handed to fit_generator: per-epoch checkpoint, best-model checkpoint,
# CSV history log and learning-rate schedule.
callbacks = [
    checkpoint,
    checkpoint_best,
    csv_logger,
    lr_scheduler,
]

### Definindo Métricas

In [174]:
def top_5_accuracy(y_true, y_pred):
    """Top-5 categorical accuracy: the true class is among the 5 highest-scoring outputs.

    A named wrapper (rather than a lambda/partial) because Keras derives the
    logged metric name from the function name.
    """
    top_k = 5
    return top_k_categorical_accuracy(y_true, y_pred, k=top_k)

### Preparando modelo 

In [175]:
# Classifiers.get returns a pair, unpacked into the ResNet-18 model constructor and its
# matching input-preprocessing function.
# NOTE(review): preprocess_input is never applied by the data generators in this
# notebook — confirm whether this backbone needs it.
backbone_name = 'resnet18'
ResNet18, preprocess_input = Classifiers.get(backbone_name)

In [176]:
# ImageNet-pretrained backbone without its original classifier. The input shape follows
# the generators' target_size (3 channels: flow_from_directory's default color_mode
# is 'rgb'), so the two cannot drift apart.
base_model = ResNet18(input_shape=tuple(training_config['target_size']) + (3,),
                      weights='imagenet', include_top=False)
# New classification head: global average pooling followed by a softmax layer.
x = GlobalAveragePooling2D()(base_model.output)
# The softmax width must equal the one-hot width of the training labels.
output = Dense(len(train_generator.class_indices), activation='softmax')(x)
model = Model(inputs=[base_model.input], outputs=[output])

In [177]:
# Fine-tune the whole network: unfreeze every layer, pretrained backbone included.
# Done before model.compile(), which is when Keras takes `trainable` flags into account.
all_layers = model.layers
for net_layer in all_layers:
    net_layer.trainable = True

In [178]:
# SGD with momentum. `lr` is only the starting value: the LearningRateScheduler
# callback sets the rate at every epoch from tl_scheduler.
# decay is read from training_config so that entry actually takes effect; momentum
# can be overridden there too and defaults to 0.9.
optimizer = SGD(lr=training_config['lr'],
                momentum=training_config.get('momentum', 0.9),
                decay=training_config['decay'])

In [179]:
# Softmax head with one-hot labels -> categorical cross-entropy.
# Top-1 and top-5 accuracy are tracked for training and validation.
compile_metrics = ['accuracy', top_5_accuracy]
model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=compile_metrics,
)

In [180]:
# Epoch index training starts from; passed to fit_generator as initial_epoch.
# To resume an interrupted run, set it to the number of epochs already completed.
current_epoch = 0

### Treinando modelo

In [181]:
# Training uses floor division: the final partial batch is skipped each epoch. That is
# acceptable because the training generator shuffles (shuffle=True), so a different
# remainder is skipped every time.
# Validation uses ceiling division so every validation image is evaluated. Floor
# division would silently drop the last n % batch_size images from val_accuracy.
batch_size = training_config['batch_size']
train_steps = train_generator.n // batch_size
validation_steps = (validation_generator.n + batch_size - 1) // batch_size

history = model.fit_generator(
            train_generator,
            steps_per_epoch=train_steps,
            epochs=training_config['epochs'],
            validation_data=validation_generator,
            validation_steps=validation_steps,
            callbacks=callbacks,
            initial_epoch=current_epoch,
            use_multiprocessing=True,
            workers=8)

Epoch 1/1
