## Training the Oracle

This script assumes the `buckets` classification modality (the batch generators below are created with `modality='buckets'`).

In [None]:
from libs.generators.batch_generator_oracle import BatchGeneratorTrainOracle, BatchGeneratorValidationOracle
from libs.models.oracle_models import get_pretrained_oracle

In [None]:
import tensorflow as tf
import keras.backend as K
import keras

# Configure TensorFlow to allocate GPU memory on demand instead of reserving it
# all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
# Register the configured session with Keras. Otherwise Keras creates and uses
# its own default session, and the session built above is never used by the
# models in this notebook.
K.set_session(session)

In [None]:
# Training batches of 16 samples with bucketed targets. vgg=True presumably
# selects VGG-specific input preprocessing; confirm in the generator.
gen_train = BatchGeneratorTrainOracle(
    batch_size=16,
    vgg=True,
    modality='buckets',
)

In [None]:
# Validation batches, configured identically to the training generator.
gen_valid = BatchGeneratorValidationOracle(
    batch_size=16,
    vgg=True,
    modality='buckets',
)

## Retrieve the pre-trained model and train

In [None]:
# Build the oracle network through the project helper. As its name suggests,
# it is presumably initialised from pre-trained weights; verify in
# libs.models.oracle_models.
model = get_pretrained_oracle()

In [None]:
def distance_categorical(y_true, y_pred):
    """Return the normalised gap between the true and predicted bucket indices.

    Both tensors are reduced with ``argmax`` over their last axis. The absolute
    difference of the two indices is cast to float32 and divided by 5. The
    checkpoint callback monitors this metric (as ``val_distance_categorical``)
    to pick the best weights, so lower is better.
    """
    true_idx = K.argmax(y_true)
    pred_idx = K.argmax(y_pred)
    gap = K.cast(K.abs(true_idx - pred_idx), 'float32')
    # NOTE(review): 5 looks like (number of buckets - 1) for the 6 buckets that
    # class_weight lists, which would map the metric onto [0, 1]. Confirm.
    return gap / 5.

In [None]:
import os

from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam

# ModelCheckpoint writes straight to this path and does not create missing
# directories, so make sure the folder exists before the first save.
os.makedirs('checkpoints', exist_ok=True)

# Save only the weights with the lowest validation bucket distance. mode='min'
# is stated explicitly instead of relying on 'auto' mode inferring the
# direction from the metric name.
checkpoint = ModelCheckpoint('checkpoints/checkpoint_oracle.h5', monitor='val_distance_categorical', mode='min', verbose=1, save_best_only=True, save_weights_only=True)
model.compile(loss='categorical_crossentropy', optimizer = Adam(lr=1e-4, clipvalue= 0.3, clipnorm= 1.), metrics = ['categorical_accuracy', distance_categorical])

In [None]:
# Keras expects class_weight as a {class_index: weight} dict. A plain list is
# not applied for a single-output model, so the down-weighting of buckets 4
# and 5 would silently be lost.
model.fit_generator(
    gen_train,
    validation_data=gen_valid,
    max_queue_size=50,
    callbacks=[checkpoint],
    epochs=20,
    workers=64,
    class_weight=dict(enumerate([1., 1., 1., 1., 0.8, 0.5])),
)

## Unfreeze pre-trained layers and fine-tune

In [None]:
# Unfreeze the pre-trained layers: every layer except the last six becomes
# trainable. The last six layers (presumably the task head trained above) are
# left untouched.
for layer in model.layers[:-6]:
    layer.trainable = True

In [None]:
# Recompile so the new trainable flags take effect. Loss, metrics and clipping
# are the same as in the first phase, but the learning rate is 10x smaller
# for fine-tuning.
fine_tune_optimizer = Adam(lr=1e-5, clipvalue=0.3, clipnorm=1.)
model.compile(
    loss='categorical_crossentropy',
    optimizer=fine_tune_optimizer,
    metrics=['categorical_accuracy', distance_categorical],
)

In [None]:
# Fine-tune the whole network. class_weight must be a {class_index: weight}
# dict; a plain list is not applied for a single-output model.
# NOTE(review): `checkpoint` is the same callback object used in phase 1, so it
# presumably keeps its best phase-1 score and only saves when fine-tuning beats
# it. Confirm this is intended.
model.fit_generator(
    gen_train,
    validation_data=gen_valid,
    max_queue_size=50,
    callbacks=[checkpoint],
    epochs=20,
    workers=64,
    class_weight=dict(enumerate([1., 1., 1., 1., 0.8, 0.5])),
)

## Save model

In [None]:
from libs.keras_checkpoints import save_model

# Restore the best weights written by the checkpoint callback during training,
# then hand the model to the project's save_model helper for export.
best_weights_path = 'checkpoints/checkpoint_oracle.h5'
model.load_weights(best_weights_path)
save_model(model, 'models/oracle')