In [None]:
import time
import os
import numpy as np
import tensorflow as tf
import importlib
from datetime import datetime
from tensorflow.python.framework.ops import reset_default_graph
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

import utils

# Sub-directory names for training artefacts. `load_path` combines them as
# <base>/<run name>/<checkpoint>; the 'log' and 'test' entries are not used
# by any cell visible in this notebook.
SAVER_PATH = dict(
    base='train/',
    checkpoint='checkpoints/',
    log='logs/',
    test='test/',
)
# Wall-clock time at which this cell ran, formatted as YYYYmmdd-HHMMSS.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

def load_config(config_name):
    """Import ``configurations.<config_name>`` and return the module object.

    A short progress line is printed first. Any error raised by
    ``importlib.import_module`` (e.g. ``ModuleNotFoundError`` for an unknown
    configuration) propagates to the caller.
    """
    print("loading,", config_name)
    module_path = ".".join(("configurations", config_name))
    return importlib.import_module(module_path)

def load_path(config_name, epoch=None):
    """Build a ``tf.train.Saver`` and resolve which checkpoint to restore.

    Args:
        config_name: run name; the checkpoint directory is
            ``SAVER_PATH['base']/<config_name>/SAVER_PATH['checkpoint']``.
        epoch: if given, the checkpoint path is ``<dir>/checkpoint-<epoch>``;
            if None, the newest checkpoint found by
            ``tf.train.latest_checkpoint`` in that directory is used.

    Returns:
        ``(saver, checkpoint_path)``. The saver is created for the current
        default graph, so call this after the model has been built.
    """
    run_dir = os.path.join(SAVER_PATH['base'], config_name)
    saver = tf.train.Saver()
    checkpoint_dir = os.path.join(run_dir, SAVER_PATH['checkpoint'])
    checkpoint_prefix = os.path.join(checkpoint_dir, 'checkpoint')
    print(checkpoint_prefix)
    if epoch is not None:
        return saver, "%s-%d" % (checkpoint_prefix, epoch)
    return saver, tf.train.latest_checkpoint(checkpoint_dir)

def validate(sess, gen, prediction, X_input, X_length, test=False,
             is_training_pl=None, max_len=700, n_classes=8, test_size=514):
    """Run ``prediction`` over every batch from ``gen()`` and score the result.

    Each batch's outputs, targets and masks are zero-padded along the sequence
    axis to ``max_len`` so batches with different sequence lengths can be
    concatenated, then the whole set is scored with ``utils.proteins_acc``.

    Args:
        sess: session in which ``prediction`` can be evaluated (weights restored).
        gen: zero-argument callable returning an iterable of ``(batch, n_rows)``
            pairs. ``batch`` must provide ``'X'``, ``'length'``, ``'mask'`` and
            ``'t'``; ``n_rows`` is used as the row count of the padded buffers.
        prediction: tensor to fetch. Its value is assumed to be indexed as
            ``(n_rows, seq_len, n_classes)``.
        X_input, X_length: placeholders fed with ``batch['X']`` and
            ``batch['length']``.
        test: if True, keep exactly ``test_size`` rows of the concatenated
            results instead of the sum of the generator's ``n_rows``.
        is_training_pl: placeholder fed with ``False``. If omitted, the global
            ``is_training_pl`` is used (the behaviour of earlier versions).
        max_len: padded sequence length (default 700).
        n_classes: size of the last output axis (default 8).
        test_size: number of rows kept when ``test`` is True (default 514).

    Returns:
        ``(accs, outs, targets, masks)``: the ``utils.proteins_acc`` score and
        the padded, concatenated outputs / targets / masks that were scored.

    Raises:
        NameError: ``is_training_pl`` was neither passed nor defined globally.
        ValueError: ``gen()`` yielded no batches.
    """
    if is_training_pl is None:
        # Backward compatibility: callers used to rely on a notebook-global
        # placeholder of this name (e.g. the one unpacked from config.model).
        try:
            is_training_pl = globals()['is_training_pl']
        except KeyError:
            raise NameError("validate() needs `is_training_pl`: pass it explicitly "
                            "or define it as a global") from None

    fetches = [prediction]  # loop-invariant
    outs, targets, masks = [], [], []
    n_rows_total = 0
    for batch, n_rows in gen():
        feed_dict = {X_input: batch['X'], X_length: batch['length'],
                     is_training_pl: False}
        out = sess.run(fetches=fetches, feed_dict=feed_dict)[0]
        seq_len = out.shape[1]
        # Zero-pad every batch to max_len so batches can be stacked on axis 0.
        h_out = np.zeros((n_rows, max_len, n_classes), dtype="float32")
        h_out[:, :seq_len, :] = out
        h_mask = np.zeros((n_rows, max_len), dtype="float32")
        h_mask[:, :seq_len] = batch['mask']
        h_targets = np.zeros((n_rows, max_len), dtype="int32")
        h_targets[:, :seq_len] = batch['t']
        outs.append(h_out)
        targets.append(h_targets)
        masks.append(h_mask)
        n_rows_total += n_rows
    if not outs:
        raise ValueError("gen() yielded no batches; nothing to validate")

    # On the test split keep exactly `test_size` rows (the generator's row
    # counts are not used there); otherwise keep every row it reported.
    n_keep = test_size if test else n_rows_total
    outs = np.concatenate(outs, axis=0)[:n_keep]
    targets = np.concatenate(targets, axis=0)[:n_keep]
    masks = np.concatenate(masks, axis=0)[:n_keep]
    accs = utils.proteins_acc(outs, targets, masks)
    return accs, outs, targets, masks

In [None]:
# Each entry: (configuration module name, CRF flag passed to `config.model`,
# checkpoint epochs to restore). Every checkpoint of an entry is scored on its
# own, and the entry is also scored as an ensemble (elementwise mean of the
# checkpoints' outputs).
name_epochs = [("plain", 0, [851, 951, 1001, 1051, 901]),
               ("plain", 1, [1001, 1051, 1151, 651, 1401]),
               ("plain_bn", 0, [701, 751, 501, 601, 851]),
               ("plain_bn", 1, [751, 801, 601, 851, 651]),
               ("plain_dropout", 0, [1401, 1351, 1151, 1551, 1601]),
               ("plain_dropout", 1, [1151, 1251, 1351, 1451, 1101]),
               ("plain_bn_dropout", 0, [951, 801, 1151, 901, 1001]),
               ("plain_bn_dropout", 1, [1001, 1051, 901, 1301, 1101])]

# Loop-invariant session options: each session may use up to 80% of GPU memory.
gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)


def _ensemble_accuracy(model_outs, targets, masks):
    """Score the elementwise mean of several checkpoints' outputs.

    ``model_outs`` is a non-empty list of equally-shaped output arrays, one
    per checkpoint. The mean is taken over however many arrays are present,
    and neither the list nor its arrays are modified.
    """
    if len(model_outs) == 0:
        raise ValueError("no model outputs to ensemble")
    ensemble_outs = np.mean(np.stack(model_outs, axis=0), axis=0)
    return utils.proteins_acc(ensemble_outs, targets, masks)


meta_data = []
for config_name, crf_on, epochs in name_epochs:
    load_name = "%s-%d" % (config_name, crf_on)
    print(load_name)
    # importlib caches imported modules, so loading once per entry returns the
    # same module object the per-checkpoint calls would.
    config = load_config(config_name)
    valid_accs = []
    total_valid_outs = []
    test_accs = []
    total_test_outs = []
    for epoch in epochs:
        print("  %d" % epoch)
        # Rebuild the model in a fresh default graph for every checkpoint.
        reset_default_graph()
        data_gen = config.data_gen
        # `validate` looks up `is_training_pl` as a notebook global, so keep this name.
        X_input, X_length, t_input, t_input_hot, t_mask, is_training_pl, \
            prediction, loss, accuracy, train_op, global_step = config.model(crf_on)
        checkpoint_saver, latest_checkpoint = load_path(load_name, epoch)
        print(latest_checkpoint)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_opts)) as sess:
            checkpoint_saver.restore(sess, latest_checkpoint)
            accs, outs, valid_targets, valid_masks = validate(
                sess, data_gen.gen_valid, prediction, X_input, X_length)
            print("  valid accs: %f" % accs)
            valid_accs.append(accs)
            total_valid_outs.append(outs)
            accs, outs, test_targets, test_masks = validate(
                sess, data_gen.gen_test, prediction, X_input, X_length, test=True)
            print("  test accs: %f" % accs)
            test_accs.append(accs)
            total_test_outs.append(outs)

    mean_valid_acc = np.mean(valid_accs)
    mean_test_acc = np.mean(test_accs)
    ensemble_valid_acc = _ensemble_accuracy(total_valid_outs, valid_targets, valid_masks)
    ensemble_test_acc = _ensemble_accuracy(total_test_outs, test_targets, test_masks)
    print("------------------------")
    print("Average valid accuracy: %f" % mean_valid_acc)
    print("Ensemble valid accuracy: %f" % ensemble_valid_acc)
    print("Average test accuracy: %f" % mean_test_acc)
    print("Ensemble test accuracy: %f" % ensemble_test_acc)
    print()
    meta_data.append((config_name, crf_on, mean_valid_acc, ensemble_valid_acc,
                      mean_test_acc, ensemble_test_acc))


In [None]:
# One row per (configuration, CRF flag): mean single-checkpoint accuracy and
# ensemble accuracy on the validation and test splits.
print("name \t avrg. valid acc \t\t Ensemble valid acc \t Average test accuracy \t Ensemble test accuracy")
for row_name, row_crf, mean_valid, ensemble_valid, mean_test, ensemble_test in meta_data:
    print("%s-%d \t %f \t\t %f \t\t %f \t\t %f"
          % (row_name, row_crf, mean_valid, ensemble_valid, mean_test, ensemble_test))

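# The cells below are not supported yet: they reference names this notebook never defines (e.g. fpr_valid, total_valid_targets) and use binary class labels.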

In [None]:
import itertools

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print a confusion matrix and plot it on the current matplotlib figure.

    Args:
        cm: square 2-D array; rows are true labels, columns predicted labels
            (the layout returned by ``sklearn.metrics.confusion_matrix``).
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum so rows sum to 1; rows
            whose total is 0 are left as zeros. Both the colour image and the
            cell annotations then show the normalized values.
        title: figure title.
        cmap: matplotlib colormap used for the image.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize BEFORE drawing so the colour scale matches the annotations.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print as-is; fractions are rounded for readability.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out last so the axis labels are taken into account.
    plt.tight_layout()

In [None]:
print("validation")
# NOTE(review): this cell cannot run as written.
# - `fpr_valid`, `tpr_valid` and `total_valid_targets` are not defined anywhere
#   in this notebook.
# - `total_valid_outs` (from the evaluation loop) is a list of per-checkpoint
#   arrays padded to shape (rows, 700, 8) by `validate`, so
#   `np.argmax(..., axis=1)` does not produce one predicted label per target.
# - The two "adherent"/"non-adherent" labels do not match the 8-wide output
#   axis.
# TODO: rewrite against the ensemble outputs plus `valid_targets`/`valid_masks`.
#np.save("total_valid_outs.npy", total_valid_outs)
#np.save("total_valid_targets.npy", total_valid_targets)
plt.figure()
plt.plot(fpr_valid, tpr_valid)
plt.show()
#print(total_outs[-10:, 0])


cnf_matrix = confusion_matrix(total_valid_targets, np.argmax(total_valid_outs, axis=1))
plt.figure()
plot_confusion_matrix(cnf_matrix, ["adherent", "non-adherent"])

In [None]:
print("test")
# NOTE(review): this cell cannot run yet.
# - `fpr_test`, `tpr_test` and `total_test_targets` are not defined in this
#   notebook.
# - `total_test_outs` is a list of per-checkpoint arrays padded to
#   (rows, 700, 8) by `validate`, so `np.argmax(..., axis=1)` does not give
#   one label per target.
# - The binary labels below do not match the 8-wide output axis.
# TODO: rewrite against the ensemble outputs plus `test_targets`/`test_masks`.
#np.save("total_test_outs.npy", total_test_outs)
#np.save("total_test_targets.npy", total_test_targets)
plt.figure()
plt.plot(fpr_test, tpr_test)
plt.show()

# Compute the confusion matrix once and draw it on a single new figure.
cnf_matrix = confusion_matrix(total_test_targets, np.argmax(total_test_outs, axis=1))
plt.figure()
plot_confusion_matrix(cnf_matrix, ["adherent", "non-adherent"])