In [1]:
import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.models import *
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.xception import preprocess_input

Using TensorFlow backend.


In [2]:
# --- Configuration ---------------------------------------------------------
# Seed numpy's global RNG so later random sampling is repeatable.
np.random.seed(2018)

# Stem of the saved model file loaded from the models directory.
model_name = "Xception_L20_0.28532"

# Input resolution fed to the network, as (height, width).
model_image_size = (299, 299)
batch_size = 32

# Indices into model.layers used for the class activation map:
# final_layer provides the classifier weights, visual_layer the feature maps.
final_layer, visual_layer = 134, 131

In [None]:
# Find the image to visualize
def find_image(valid_label, right, frame=None):
    """Pick one random validation filename for a given true class.

    Parameters
    ----------
    valid_label : int
        True class index matched against the ``y_valid`` column.
    right : bool or int
        Truthy -> sample among correctly predicted images
        (``y_pred == valid_label``); falsy -> among misclassified ones.
    frame : pandas.DataFrame, optional
        Table with ``filename``, ``y_valid`` and ``y_pred`` columns.
        Defaults to the notebook-level ``df`` so existing calls keep working.

    Returns
    -------
    str or int
        A randomly sampled filename, or ``0`` when no row matches (kept as
        ``0`` so callers testing truthiness behave as before).
    """
    if frame is None:
        frame = df  # notebook-level table, resolved only at call time
    is_class = frame['y_valid'] == valid_label
    is_correct = frame['y_pred'] == valid_label
    mask = is_class & (is_correct if right else ~is_correct)
    candidates = frame.loc[mask, 'filename']
    print(candidates)
    if candidates.empty:
        return 0
    # Series.sample with no random_state draws from numpy's global RNG.
    return candidates.sample(n=1).values[0]

In [None]:
# Draw heatmap in the image
def _true_class_index(image_file):
    """Return the class digit that ends the parent folder name of ``image_file``."""
    # Accept both separators: Keras reports filenames with os.sep, so a
    # backslash-only parse breaks on POSIX systems.
    parent = os.path.basename(os.path.dirname(image_file.replace('\\', '/')))
    if not parent:
        raise ValueError("Cannot infer class from filename without a folder: %r" % image_file)
    return int(parent[-1])


def heatmap_image(model_show, weights_show, image_file,
                  image_dir=os.path.join('.', 'imgs', 'new_valid'),
                  threshold=0.2):
    """Overlay a class activation map (CAM) on one validation image.

    Parameters
    ----------
    model_show : keras Model
        Two-output model whose ``predict_on_batch`` returns
        ``(feature_maps, class_probabilities)``.
    weights_show : numpy.ndarray
        Final classifier weights; column ``k`` weights the feature-map
        channels for class ``k``.
    image_file : str
        Path relative to ``image_dir`` whose parent folder name ends in the
        true class digit (e.g. ``c9\\img_9877.jpg`` or ``c9/img_9877.jpg``).
    image_dir : str, optional
        Root directory of the validation images.
    threshold : float, optional
        Normalised CAM values at or below this are left uncoloured.

    Returns
    -------
    tuple
        ``(overlay_rgb, pred_label, true_label, pred_probability)`` -- note
        that the predicted label comes BEFORE the true label.

    Raises
    ------
    IOError
        If the image file cannot be read.
    """
    status = ["safe driving", "texting - right", "phone - right", "texting - left", "phone - left",
              "operation radio", "drinking", "reaching behind", "hair and makeup", "talking"]
    image_path = os.path.join(image_dir, image_file)
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise IOError("Could not read image: %s" % image_path)
    img = cv2.resize(img, (model_image_size[1], model_image_size[0]))

    # OpenCV loads BGR, while ImageDataGenerator (used for training and for
    # valid_generator) yields RGB; Xception's preprocess_input only rescales,
    # so the channels must be swapped here to match the model's input.
    img_in = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
    img_in = preprocess_input(img_in)
    feature_maps, predictions = model_show.predict_on_batch(np.expand_dims(img_in, axis=0))
    predictions = predictions[0]
    feature_maps = feature_maps[0]

    pred_idx = int(np.argmax(predictions))
    pred = status[pred_idx]
    prediction = predictions[pred_idx]

    val = status[_true_class_index(image_file)]

    # Only the predicted class's map is needed: weight each feature-map
    # channel by that class's classifier weight and sum over channels.
    cam = prediction * np.dot(feature_maps, weights_show[:, pred_idx])
    cam = cam.astype(np.float32)
    cam -= cam.min()
    cam_max = cam.max()
    if cam_max > 0:
        cam /= cam_max  # a perfectly flat map would otherwise divide by zero

    cam = cv2.resize(cam, (model_image_size[1], model_image_size[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap[cam <= threshold] = 0  # keep only the strongest activations

    overlay = cv2.addWeighted(img, 0.8, heatmap, 0.4, 0)
    # Flip BGR -> RGB for matplotlib display.
    return overlay[:, :, ::-1], pred, val, prediction

In [3]:
# Data generator: stream validation images in a fixed order (shuffle=False)
# so predictions can later be matched to valid_generator.filenames.
# os.path.join keeps the path portable (the original used '\\' separators).
gen = ImageDataGenerator(preprocessing_function=preprocess_input)

valid_generator = gen.flow_from_directory(os.path.join('.', 'imgs', 'new_valid'),
                                          target_size=model_image_size,
                                          shuffle=False,
                                          batch_size=batch_size,
                                          class_mode="categorical")

Found 2153 images belonging to 10 classes.


In [4]:
# Load the trained model from the models directory (portable path).
# NOTE(review): '.h' is an unusual extension for a Keras HDF5 file (usually
# '.h5'); confirm it matches the file actually saved on disk.
model = load_model(os.path.join('.', 'models', model_name + '.h'))

In [5]:
# Make prediction with the model over exactly one pass of the validation set.
# The generator remembers its batch position across cells, so rewind first;
# otherwise re-running after an interrupted pass would start mid-epoch and
# misalign predictions with valid_generator.filenames.
valid_generator.reset()
predict = []
y = []
for _ in range(len(valid_generator)):
    x_valid, y_valid = next(valid_generator)
    predict.append(model.predict_on_batch(x_valid))
    y.append(y_valid)
# Stack per-batch arrays once instead of extending row by row.
y = np.concatenate(y).astype(float)
predict = np.concatenate(predict).astype(float)

In [10]:
# Classify the right and wrong predictions: pair each validation file with its
# true and predicted class index (rows align because shuffle=False).
fname = valid_generator.filenames
y_true = np.argmax(y, axis=1)  # labels are one-hot (class_mode="categorical")
y_ture = y_true  # alias kept for the original misspelled name
y_pred = np.argmax(predict, axis=1)
# A dict of pd.Series would silently NaN-pad a shorter column; fail loudly instead.
assert len(fname) == len(y_true) == len(y_pred), "predictions misaligned with filenames"

df = pd.DataFrame({'filename': fname,
                   'y_valid': y_true,
                   'y_pred': y_pred})
df.head()

In [11]:
# Build a two-output model for CAM visualisation: it emits the output of
# `visual_layer` alongside the usual class predictions. `weights` is the first
# weight array of `final_layer` (presumably the dense kernel -- confirm).
layer_output = model.layers[visual_layer].output
weights = model.layers[final_layer].get_weights()[0]
model2 = Model(inputs=model.input, outputs=[layer_output, model.output])

Unnamed: 0,filename,y_pred,y_test
2148,c9\img_9877.jpg,0,9
2149,c9\img_99104.jpg,9,9
2150,c9\img_993.jpg,0,9
2151,c9\img_99569.jpg,9,9
2152,c9\img_99949.jpg,9,9


In [14]:
# Draw heatmap for each class, and save.
# flag 0 -> a misclassified example, flag 1 -> a correctly classified one.
output_dir = os.path.join('.', 'visualization')
os.makedirs(output_dir, exist_ok=True)  # savefig fails if the folder is missing
for status_num in range(10):
    for flag in [0, 1]:
        x = find_image(status_num, flag)
        if not x:
            continue  # no image of this class with this outcome
        # heatmap_image returns (image, predicted label, true label, probability);
        # unpacking in that order keeps the title's Valid/Pred labels correct.
        img_heatmap, pred, val, prediction = heatmap_image(model2, weights, image_file=x)
        fig, ax = plt.subplots()
        ax.set_title('Valid: %s | Pred: %s %.2f%%' % (val, pred, prediction * 100))
        ax.axis('off')
        ax.imshow(img_heatmap)
        fig.savefig(os.path.join(output_dir, 'status' + str(status_num) + '_' + str(flag) + '.png'))
        plt.close(fig)  # avoid piling up open figures across the loop
