In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn import tree
from sklearn.metrics import recall_score
from sklearn.metrics import roc_curve


In [2]:
class Visualize:
    """Save diagnostic plots for a classifier as PNG files.

    Every plot is written to ``<path>/<prefix>-<kind>.png``. Figures are
    left open after saving (so an interactive/inline backend can still
    display them); callers producing many plots may want to call
    ``plt.close('all')`` themselves.
    """

    # Display names used by ``tree_plt`` for the binary target labels.
    _CLASS_LABELS = {1: 'T', 0: 'F'}

    def __init__(self, path, prefix):
        """Store the output directory and the file-name prefix.

        Args:
            path: Directory the images are written to.
            prefix: Prefix prepended to every image file name.
        """
        self.path = path
        self.prefix = prefix

    def setpath(self, path):
        """Change the directory subsequent plots are written to."""
        self.path = path

    def _out_path(self, kind):
        """Return the output file path for the plot identified by *kind*."""
        # os.path.join picks the platform separator; a hard-coded backslash
        # would create a file literally named 'dir\\name.png' outside Windows.
        return os.path.join(self.path, f'{self.prefix}-{kind}.png')

    @staticmethod
    def _top_features(importances, columns, feat_len):
        """Return (names, values) of the *feat_len* largest importances.

        Ranking is done on indices with a stable sort, so features sharing
        the same importance each keep their own name (looking a value up
        with ``list.index`` would return the first match every time).
        At most ``len(importances)`` entries are returned.
        """
        values = list(importances)
        order = sorted(range(len(values)), key=values.__getitem__,
                       reverse=True)
        chosen = order[:max(feat_len, 0)]
        return [columns[i] for i in chosen], [values[i] for i in chosen]

    def roc_curve_plt(self, y_test, y_pred):
        """Plot the ROC curve and save it as ``<prefix>-roc.png``.

        Args:
            y_test: True binary labels.
            y_pred: Values passed to ``roc_curve`` as the score argument
                (scores/probabilities give a full curve; hard 0/1
                predictions give a single operating point).
        """
        fpr, tpr, _ = roc_curve(y_test, y_pred)

        plt.figure(figsize=(8, 8))
        plt.title('ROC Curve')
        plt.plot(fpr, tpr, linewidth=2)
        # Diagonal reference line.
        plt.plot([0, 1], [0, 1], 'k--')
        plt.axis([-0.005, 1, 0, 1.005])
        plt.xticks(np.arange(0, 1, 0.05), rotation=90)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate (Recall)")
        plt.savefig(self._out_path('roc'))

    def ccp_plt(self, clfs, alphas, test):
        """Plot recall against ccp_alpha and save it as ``<prefix>-ccp.png``.

        Args:
            clfs: Fitted classifiers, one per alpha value.
            alphas: ccp_alpha values. May contain one trailing value more
                than ``clfs`` (it is dropped), or match ``clfs`` in length.
            test: Pair ``[X_test, y_test]`` used to score each classifier.
        """
        X_test, y_test = test[0], test[1]
        scores = [recall_score(y_test, clf.predict(X_test)) for clf in clfs]

        # Historically the last alpha was always dropped (presumably the
        # caller trims the final classifier of the pruning path but not its
        # alpha -- confirm against caller). Only trim when lengths differ so
        # already-aligned inputs also work.
        xs = alphas if len(alphas) == len(scores) else alphas[:-1]

        plt.figure(figsize=(10, 6))
        plt.title('ccp Alpha')
        plt.grid()
        plt.plot(xs, scores)
        plt.xlabel("Alpha Score")
        plt.ylabel("Recall Score")
        plt.savefig(self._out_path('ccp'))

    def sen_spe_plt(self, sen, spe, thresh):
        """Plot sensitivity and specificity against the threshold.

        Saved as ``<prefix>-sen-spe.png``.

        Args:
            sen: Sensitivity values, one per threshold.
            spe: Specificity values, one per threshold.
            thresh: Threshold values used as the x axis.
        """
        plt.figure(figsize=(8, 8))
        plt.title('Sensitivity vs Specificity')
        plt.plot(thresh, sen, 'b-', label='Sensitivity')
        plt.plot(thresh, spe, 'g-', label='Specificity')
        plt.ylabel("Score")
        plt.xlabel("Threshold")
        plt.legend(loc='best')
        plt.savefig(self._out_path('sen-spe'))

    def feature_sig_plt(self, clf, db, feat_len):
        """Plot a bar chart of the most significant features.

        Saved as ``<prefix>-feat-sig.png``.

        Args:
            clf: Fitted classifier exposing ``feature_importances_``.
            db: Object whose ``columns`` are indexed by feature position
                (assumes the feature columns come first, in training
                order -- confirm against caller).
            feat_len: Number of top features to show; clamped to the number
                of available features.
        """
        # TODO: should probably add a variable for variance allowance.
        best_feats, best_values = self._top_features(
            clf.feature_importances_, db.columns, feat_len)

        plt.figure(figsize=(16, 8))
        plt.title(f'{len(best_feats)} most significant features')
        plt.bar(best_feats, best_values)
        plt.xlabel("Features")
        plt.ylabel("")
        plt.savefig(self._out_path('feat-sig'))

    def tree_plt(self, clf, db):
        """Draw the decision tree, save it as ``<prefix>-tree.png``, show it.

        Args:
            clf: Fitted decision-tree classifier.
            db: Data frame whose columns, except the last one, are the
                feature names (assumes the last column is the target).
        """
        feats = list(db.columns[:-1])
        # plot_tree expects one name per class in ascending class order,
        # not one label per data row; clf.classes_ is already sorted.
        classes = [self._CLASS_LABELS.get(c, str(c)) for c in clf.classes_]

        _, ax = plt.subplots(figsize=(140, 100))
        tree.plot_tree(clf, filled=True, ax=ax, feature_names=feats,
                       class_names=classes)
        plt.savefig(self._out_path('tree'))
        plt.show()