In [25]:
import pandas as pd
import pickle
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics import roc_auc_score, accuracy_score, recall_score, precision_score, roc_curve
import matplotlib.pyplot as plt
from joblib import load
from keras.models import load_model

Using TensorFlow backend.


In [26]:
# Held-out test split plus the stopword list that was used when the vectorizers were fit.
test_data = pd.read_pickle('../data/test_data.pkl')

# NOTE(review): pickle.load can execute arbitrary code; only point this at
# trusted, locally produced artifacts.
with open('../data/stopwords.pkl', mode='rb') as stopwords_file:
    stopwords = pickle.load(stopwords_file)

In [27]:
# One bundle per model family. Each pairs a fitted model with the vectorized test
# matrix it must be scored on, once for Count features and once for TF-IDF features.
# NOTE(review): joblib.load unpickles its input; only load trusted files.
svm = dict(
    count_model=load('../results/models/Count-SVM_model.joblib'),
    count_test=load('../data/Count-SVM_test_vec.joblib'),
    tfidf_model=load('../results/models/TF-IDF-SVM_model.joblib'),
    tfidf_test=load('../data/TF-IDF-SVM_test_vec.joblib'),
)

rf = dict(
    count_model=load('../results/models/Count-RF_model.joblib'),
    count_test=load('../data/Count-RF_test_vec.joblib'),
    tfidf_model=load('../results/models/TF-IDF-RF_model.joblib'),
    tfidf_test=load('../data/TF-IDF-RF_test_vec.joblib'),
)

# The neural-network models are Keras HDF5 files, so they go through load_model;
# their test matrices are still joblib dumps.
nn = dict(
    count_model=load_model('../results/models/Count-NN_model.h5'),
    count_test=load('../data/Count-NN_test_vec.joblib'),
    tfidf_model=load_model('../results/models/TF-IDF-NN_model.h5'),
    tfidf_test=load('../data/TF-IDF-NN_test_vec.joblib'),
)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where



In [30]:
class CompareModels(object):
    """Plot ROC curves for several fitted text classifiers on one shared figure.

    Each of ``svm_dict``, ``rf_dict`` and ``nn_dict`` must provide the keys
    ``'count_model'``, ``'count_test'``, ``'tfidf_model'`` and ``'tfidf_test'``,
    i.e. a fitted model plus the test feature matrix it is scored on.
    ``test_data['label']`` supplies the true labels. Its rows are assumed to be in
    the same order as every ``*_test`` matrix (NOTE(review): not checked here).
    """

    def __init__(self, test_data, svm_dict, rf_dict, nn_dict):
        """Store the labelled test frame and the per-family model/feature bundles."""
        self.test_data = test_data
        self.svm_params = svm_dict
        self.rf_params = rf_dict
        self.nn_params = nn_dict
        # NOTE(review): not read by any method of this class. It is kept so external
        # code touching ``.scores`` keeps working.
        self.scores = pd.DataFrame()

    def plot_roc_curves(self, save_path='../results/images/all_ROC_curve.png'):
        """Compute every model's ROC curve and save them together as one figure.

        Args:
            save_path: Destination passed to ``savefig``. Defaults to the path this
                notebook has always written to.

        Returns:
            None. The figure is written to ``save_path`` and then closed, so it is
            not displayed inline.
        """
        # Each spec is (legend label, scorer, params bundle, key prefix, colour).
        # The order matches the legend order of the original figure.
        # NOTE(review): the light shades are hard to see on a white background.
        curve_specs = [
            ('Count-SVM', self.get_fpr_tpr, self.svm_params, 'count', '#fadbd8'),       # light red
            ('TF-IDF-SVM', self.get_fpr_tpr, self.svm_params, 'tfidf', '#943126'),      # dark red
            ('Count-RF', self.get_fpr_tpr, self.rf_params, 'count', '#ebdef0'),         # light purple
            ('TF-IDF-RF', self.get_fpr_tpr, self.rf_params, 'tfidf', '#512e5f'),        # dark purple
            ('Count-NN', self.get_fpr_tpr_keras, self.nn_params, 'count', '#d6eaf8'),   # light blue
            ('TF-IDF-NN', self.get_fpr_tpr_keras, self.nn_params, 'tfidf', '#21618c'),  # dark blue
        ]

        # Score every model before opening a figure. If scoring fails, no
        # half-drawn figure is left behind.
        curves = [
            (label, color, scorer(params[prefix + '_model'], params[prefix + '_test']))
            for label, scorer, params, prefix, color in curve_specs
        ]

        fig, ax = plt.subplots(figsize=(12, 10))
        try:
            # The diagonal is the chance-level reference line.
            ax.plot([0, 1], [0, 1], linestyle='--', color='#D3D3D3')
            for label, color, (fpr, tpr) in curves:
                ax.plot(fpr, tpr, color=color, label=label)
            ax.set_title('Receiver Operating Characteristic Curves', fontsize=20)
            ax.set_xlabel('False Positive Rate')
            ax.set_ylabel('True Positive Rate')
            ax.legend()
            fig.savefig(save_path, bbox_inches='tight')
        finally:
            # Release the figure even if drawing or saving raised.
            plt.close(fig)
        return None

    def get_fpr_tpr(self, clf, X_vec):
        """Return ``(fpr, tpr)`` for a classifier exposing ``predict_proba``.

        Column 1 of ``predict_proba(X_vec)`` is used as the score. For a binary
        sklearn estimator that is presumably the positive class (``clf.classes_[1]``).
        """
        y_score = clf.predict_proba(X_vec)[:, 1]
        return self._roc_points(y_score)

    def get_fpr_tpr_keras(self, model, X_vec):
        """Return ``(fpr, tpr)`` for a Keras model whose ``predict`` yields scores.

        The prediction is flattened to 1-D. This assumes a single output unit
        (NOTE(review): a 2-unit softmax head would produce twice as many scores as
        labels and ``roc_curve`` would reject it).
        """
        y_score = model.predict(X_vec).flatten()
        return self._roc_points(y_score)

    def _roc_points(self, y_score):
        """Compute ROC ``(fpr, tpr)`` for ``y_score`` against ``test_data['label']``."""
        y_true = self.test_data['label'].values
        fpr, tpr, _thresholds = roc_curve(y_true, y_score)
        return fpr, tpr
        

In [31]:
# Compare all six model/feature combinations on the shared test labels. The ROC
# figure is written to disk and closed, so nothing renders inline here.
compare_models = CompareModels(test_data=test_data, svm_dict=svm, rf_dict=rf, nn_dict=nn)
compare_models.plot_roc_curves()