In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from tqdm import tqdm

import keras
from keras import optimizers
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from keras.preprocessing.image import ImageDataGenerator

from xception import Xception, preprocess_input

import matplotlib.pyplot as plt
%matplotlib inline

import sys
sys.path.append('../training_utils/')
from diagnostic_tools import top_k_accuracy, per_class_accuracy,\
    count_params, entropy, model_calibration, most_confused_classes,\
    most_inaccurate_k_classes, show_errors
    
from sklearn.metrics import accuracy_score, log_loss

In [None]:
import os

# Root folder that holds the 'train_no_resizing' and 'val' image directories.
# Set the DATA_DIR environment variable to point somewhere else; when it is unset
# (or empty) the original default location is used.
# os.path.join(..., '') guarantees a trailing separator, because later cells build
# paths by plain string concatenation (data_dir + 'val').
data_dir = os.path.join(os.environ.get('DATA_DIR') or '/home/ubuntu/data/', '')

In [None]:
# Settings shared by the training and the validation pipelines.
shared_settings = dict(
    data_format='channels_last',
    preprocessing_function=preprocess_input
)
# Every image is resized to 299x299 and served in batches of 64.
flow_settings = dict(target_size=(299, 299), batch_size=64)

# Training images are randomly augmented (rotation, shifts, shear, zoom,
# channel shift, horizontal flip) before preprocessing.
data_generator = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.001,
    zoom_range=0.3,
    channel_shift_range=0.1,
    horizontal_flip=True,
    fill_mode='reflect',
    **shared_settings
)

# Validation images get no augmentation, only the preprocessing function.
data_generator_val = ImageDataGenerator(**shared_settings)

train_generator = data_generator.flow_from_directory(
    data_dir + 'train_no_resizing', **flow_settings
)

# shuffle=False keeps the validation batches in a fixed order.
val_generator = data_generator_val.flow_from_directory(
    data_dir + 'val', shuffle=False, **flow_settings
)

# Model

In [None]:
# Build the network with the constructor imported from the local `xception` module.
# NOTE(review): `weight_decay` presumably sets the strength of the weight
# regularization - confirm against xception.py.
model = Xception(weight_decay=1e-8)

# Training

In [None]:
# Adam with a learning rate of 1e-3, categorical cross-entropy as the loss.
adam = optimizers.Adam(lr=1e-3)
# Metrics reported during training: plain accuracy and top-k categorical accuracy.
tracked_metrics = ['accuracy', 'top_k_categorical_accuracy']

model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=tracked_metrics)

In [None]:
# Stop training once validation accuracy has not improved by at least 0.01
# for 4 epochs in a row.
early_stopping = EarlyStopping(monitor='val_acc', min_delta=0.01, patience=4)
# Currently disabled callback, kept for reference:
#ReduceLROnPlateau(monitor='val_acc', factor=0.1, patience=2, epsilon=0.007)

# Up to 30 epochs of 266 training batches each; validation runs on 80 batches
# after every epoch.
model.fit_generator(
    train_generator,
    steps_per_epoch=266,
    epochs=30,
    verbose=1,
    callbacks=[early_stopping],
    validation_data=val_generator,
    validation_steps=80,
    workers=4
)

# Loss/epoch plots

In [None]:
# Training vs. validation loss per epoch, read from the History object on the model.
history = model.history.history

fig, ax = plt.subplots()
for key, label in (('loss', 'train'), ('val_loss', 'val')):
    ax.plot(history[key], label=label)
ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('loss');

In [None]:
# Training vs. validation accuracy per epoch, read from the History object on the model.
history = model.history.history

fig, ax = plt.subplots()
for key, label in (('acc', 'train'), ('val_acc', 'val')):
    ax.plot(history[key], label=label)
ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('accuracy');

In [None]:
# Training vs. validation top-k categorical accuracy per epoch,
# read from the History object on the model.
history = model.history.history

fig, ax = plt.subplots()
for key, label in (
    ('top_k_categorical_accuracy', 'train'),
    ('val_top_k_categorical_accuracy', 'val'),
):
    ax.plot(history[key], label=label)
ax.legend()
ax.set_xlabel('epoch')
ax.set_ylabel('top5_accuracy');

# Error analysis

### get human readable class names

In [None]:
# folder name (an integer) -> human readable class name.
# The .npy file holds a Python object (presumably a dict saved with np.save), so it
# must be loaded with allow_pickle=True: NumPy >= 1.16.3 refuses pickled data by
# default. .item() unwraps the 0-d object array.
# NOTE: unpickling can execute arbitrary code - only load files you trust.
folder_to_name = np.load(
    '../preprocessing_utils/decode.npy', allow_pickle=True
).item()

# val_generator.class_indices maps folder name -> class index; combine the two
# mappings to get class index -> class name.
decode = {
    class_index: folder_to_name[int(folder)]
    for folder, class_index in val_generator.class_indices.items()
}

### get all predictions and all misclassified images 

In [None]:
val_predictions = []        # model outputs, one array per batch
val_true_targets = []       # true class indices
erroneous_samples = []      # (preprocessed) input images the model got wrong
erroneous_targets = []      # true class indices of those images
erroneous_predictions = []  # model outputs for those images

# Rewind the iterator so the pass starts at the first validation batch, no matter
# how far earlier cells (or a previous run of this cell) advanced it.
val_generator.reset()

# The generator loops forever, so bound the pass explicitly: exactly one sweep over
# the validation set, derived from the generator instead of a hard-coded batch count.
n_val_batches = int(np.ceil(val_generator.samples / float(val_generator.batch_size)))

for batch_number in range(n_val_batches):
    x_batch, y_batch = next(val_generator)
    preds = model.predict_on_batch(x_batch)

    # y_batch is one-hot encoded; argmax over axis 1 recovers the class index.
    true_classes = y_batch.argmax(1)
    miss = true_classes != preds.argmax(1)

    val_predictions.append(preds)
    val_true_targets.append(true_classes)
    erroneous_samples.append(x_batch[miss])
    erroneous_targets.append(true_classes[miss])
    erroneous_predictions.append(preds[miss])

# Stack the per-batch pieces into single arrays.
val_predictions = np.concatenate(val_predictions, axis=0)
val_true_targets = np.concatenate(val_true_targets, axis=0)
erroneous_samples = np.concatenate(erroneous_samples, axis=0)
erroneous_targets = np.concatenate(erroneous_targets, axis=0)
erroneous_predictions = np.concatenate(erroneous_predictions, axis=0)

### number of misclassified images (the validation set has 5120 images in total)

In [None]:
# One entry per misclassified validation image, so the first dimension is the error count.
n_errors = erroneous_targets.shape[0]
n_errors

### logloss and different accuracies

In [None]:
# Log loss of the model outputs against the true class indices.
log_loss(y_true=val_true_targets, y_pred=val_predictions)

In [None]:
# Top-1 accuracy: the highest scoring class is taken as the prediction.
top1_predictions = val_predictions.argmax(axis=1)
accuracy_score(val_true_targets, top1_predictions)

In [None]:
# Top-k accuracy for several values of k (helper from diagnostic_tools).
top_k_scores = top_k_accuracy(val_true_targets, val_predictions, k=(2, 3, 4, 5, 10))
print(top_k_scores)

### entropy of predictions

In [None]:
# Boolean mask: True where the top-1 prediction equals the true class.
hits = val_predictions.argmax(1) == val_true_targets

# Overlay the entropy histograms of correct and wrong predictions.
# density=True rescales each histogram to unit area, so the two groups can be
# compared despite their different sizes. It replaces the `normed` keyword, which
# matplotlib deprecated in 2.1 and removed in 3.1.
for mask, alpha, label in (
    (hits, 0.7, 'correct prediction'),
    (~hits, 0.5, 'misclassification'),
):
    plt.hist(
        entropy(val_predictions[mask]), bins=30,
        density=True, alpha=alpha, label=label
    )
plt.legend()
plt.xlabel('entropy of predictions');

### probabilistic calibration of the model

In [None]:
# Calibration check using the helper from diagnostic_tools, with 10 bins.
# NOTE(review): presumably compares predicted confidence with observed accuracy
# per bin - confirm in diagnostic_tools.
model_calibration(val_true_targets, val_predictions, n_bins=10)

### per class accuracies

In [None]:
# Accuracy broken down by class (helper from diagnostic_tools).
# NOTE(review): presumably one value per class index - confirm in diagnostic_tools.
per_class_acc = per_class_accuracy(val_true_targets, val_predictions)
# Display the result as the cell output.
per_class_acc

In [None]:
# Histogram of the per-class accuracies.
fig, ax = plt.subplots()
ax.hist(per_class_acc);

In [None]:
# Ask the diagnostic_tools helper for the 10 least accurate classes, passing `decode`
# (class index -> class name), and keep the second element of its result.
# NOTE(review): that element presumably holds the human readable class names - confirm.
A = most_inaccurate_k_classes(per_class_acc, 10, decode)[1]
# Display the result as the cell output.
A

### most confused classes

In [None]:
# Class pairs the model mixes up, from the diagnostic_tools helper; only the first
# of the two returned values is kept.
# NOTE(review): min_n_confusions=4 presumably drops pairs confused fewer than
# 4 times - confirm in diagnostic_tools.
confused_pairs, _ = most_confused_classes(
    val_true_targets, val_predictions, decode, min_n_confusions=4
)
print(confused_pairs)

# Results

In [None]:
# Final validation scores over 80 batches: the loss followed by the compiled metrics.
model.evaluate_generator(val_generator, steps=80)

In [None]:
#model.save_weights('xception_weights.hdf5')