In [None]:
import os

import keras
import numpy as np
import pandas as pd

from tqdm import tqdm_notebook

from keras import backend as K

In [None]:
# define metric
# balancedAccuracy is imported from a module of the same name (presumably
# project-local). The instance's `balanced_acc` attribute is what gets passed
# to keras.models.load_model as a custom object in the evaluation loop.
from balancedAccuracy import balancedAccuracy
num_classes = 3  # same count as the entries in class_names (MEL, NV, BKL)
bacc_metric = balancedAccuracy(num_classes)

In [None]:
# Paths of the saved Keras models to evaluate; each entry is passed to
# keras.models.load_model in the evaluation loop, and its name is also used
# as the prefix of the prediction CSV files written there.
# NOTE(review): these look like placeholder names - replace with real checkpoint paths.
weights = ["some model.h5", "other model.h5"]


In [None]:
# np.load on an .npz archive returns a lazy NpzFile that keeps the underlying
# file open until it is closed. Pull out the arrays that are needed inside a
# `with` block so the handle is released as soon as they are in memory.
# After each block `data` / `testData` are still bound but closed.
with np.load("HAMAUG.npz") as data:
    imageValList = data["imageValList"]
    # Only the first three target columns are kept, matching the three
    # entries of class_names.
    targetValList = data["targetValList"][:, :3]

with np.load("TESTHAM.npz") as testData:
    testList = testData["testList"]
    # NOTE(review): unlike the validation targets, these are not sliced to
    # three columns - confirm the test file already holds exactly the three
    # evaluated classes, otherwise argmax can yield labels outside class_names.
    targetTestList = testData["targetTestList"]

In [None]:
# Change the preprocessing function depending on the type of CNN being
# evaluated (i.e. import a different preprocess_input for ResNet); the
# DenseNet variant is used here.
#
# NOTE(review): both arrays are overwritten with their preprocessed versions,
# so this cell is not idempotent - running it a second time without reloading
# the data preprocesses already-preprocessed images.

from keras.applications.densenet import preprocess_input 
imageValList = preprocess_input(imageValList)
testList = preprocess_input(testList)

In [None]:
# plot_confusion_matrix function
import itertools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Display names of the classes; passed as `classes=` to plot_confusion_matrix,
# where they become the tick labels on both axes.
# NOTE(review): order is assumed to match the target column / label index - confirm.
class_names = ["MEL", "NV", "BKL"]

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix and return the balanced accuracy.

    The matrix is drawn on the current pyplot figure; creating the figure
    and calling ``plt.show()`` is left to the caller.

    Parameters
    ----------
    cm : array-like
        Confusion matrix with true classes along the rows and predicted
        classes along the columns.
    classes : sequence of str
        Tick labels, one per class, in matrix order.
    normalize : bool, default False
        If True, each row is divided by its sum before being displayed;
        otherwise the raw counts are displayed.
    title : str
        Title of the plot.
    cmap : matplotlib colormap
        Colormap handed to ``plt.imshow``.

    Returns
    -------
    float
        Balanced accuracy, i.e. the mean of the per-class recalls
        (diagonal / row sum), taken over the classes that have at least one
        true sample. NaN if no class has any sample. The value no longer
        depends on ``normalize``: previously ``normalize=False`` returned the
        mean of the raw diagonal counts.
    """
    cm = np.asarray(cm)

    # Row-normalize once, up front. Rows that sum to zero (a class with no
    # true samples) are left at 0 instead of producing 0/0 = NaN.
    row_totals = cm.sum(axis=1)
    has_support = row_totals > 0
    row_normalized = np.divide(cm.astype('float'), row_totals[:, np.newaxis],
                               out=np.zeros(cm.shape, dtype=float),
                               where=has_support[:, np.newaxis])

    # Balanced accuracy = mean recall over the classes that actually occur.
    recalls = np.diag(row_normalized)
    supported = has_support[:recalls.size]
    balanced_accuracy = np.mean(recalls[supported]) if supported.any() else np.nan

    if normalize:
        cm = row_normalized
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(balanced_accuracy)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate every cell; switch to white text on dark cells for contrast.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout after the axis labels exist so they are accounted for.
    plt.tight_layout()

    return balanced_accuracy
    


In [None]:
def _plot_split_confusion_matrix(images, targets, title):
    """Plot the normalized confusion matrix of the global ``model`` on one data split.

    ``targets`` and the model output are both reduced to class indices with
    ``argmax(1)``. The matrix is drawn with ``plot_confusion_matrix`` and the
    balanced accuracy it returns is printed.

    Side effects: sets NumPy's global print precision to 2, opens a new
    pyplot figure and shows it.
    """
    y_true = targets.argmax(1)
    y_pred = model.predict(images).argmax(1)

    cnf_matrix = confusion_matrix(y_true, y_pred)
    np.set_printoptions(precision=2)
    plt.figure()
    balanced_accuracy = plot_confusion_matrix(cnf_matrix, classes=class_names,
                                              normalize=True, title=title)
    print("Balanced Accuracy: " + str(balanced_accuracy))
    plt.show()


def valMatrix():
    """Plot the confusion matrix of the current ``model`` on the validation data."""
    _plot_split_confusion_matrix(imageValList, targetValList,
                                 'Normalized Validation confusion matrix')


def trainMatrix():
    """Plot the confusion matrix of the current ``model`` on the training data.

    NOTE(review): ``imageListUltra`` and ``targetListUltra`` are not defined
    anywhere in the visible notebook; they must be loaded before this is
    called, otherwise it raises NameError.
    """
    _plot_split_confusion_matrix(imageListUltra, targetListUltra,
                                 'Normalized Training confusion matrix')


def testMatrix():
    """Plot the confusion matrix of the current ``model`` on the test data."""
    # The title used to read "Training" (copy-paste from trainMatrix).
    _plot_split_confusion_matrix(testList, targetTestList,
                                 'Normalized Test confusion matrix')

In [None]:
# Evaluate every saved model in turn: write its raw validation/test
# predictions to CSV and plot its validation confusion matrix.
for filename in tqdm_notebook(weights):
    # The custom metric is registered under the name 'balanced_acc' so that
    # load_model can resolve it (presumably the models were compiled with it).
    model = keras.models.load_model(filename,
                               custom_objects={'balanced_acc':bacc_metric.balanced_acc})
    valPredictions = pd.DataFrame(model.predict(imageValList))
    testPredictions = pd.DataFrame(model.predict(testList))

    # Drop the whole extension. The previous `filename[:-2]` removed only
    # "h5" and left the dot behind, yielding names like "model.-test.csv".
    output_stem = os.path.splitext(filename)[0]
    valPredictions.to_csv(output_stem + "-validation.csv")
    testPredictions.to_csv(output_stem + "-test.csv")

    # valMatrix reads the global `model` bound above.
    valMatrix()

    # Release the model and reset the Keras backend state before loading
    # the next one.
    del model
    K.clear_session()