In [42]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import metrics
from sklearn.utils.multiclass import unique_labels
# Notebook-wide plotting defaults: start from seaborn's base theme, then pick the
# Set2 qualitative palette and a white background with grid lines.
sns.set()
sns.set_palette('Set2')
sns.set_style('whitegrid')
# Applied after sns.set() so the theme reset does not overwrite them:
# figure titles one step larger than the base font, legend text at 15.
plt.rcParams['figure.titlesize'] = 'larger'
plt.rcParams['legend.fontsize'] = 15.0

## ROC Curve

In [None]:
def plot_ROC(fpr, tpr, auc, title='ROC curve', figsize=(8, 6)):
    """Plot a ROC curve with its AUC in the legend, then show the figure.

    Args:
        fpr: False-positive rates (x coordinates).
        tpr: True-positive rates (y coordinates), same length as ``fpr``.
        auc: Area under the curve, shown with three decimals in the legend.
        title: Figure title.
        figsize: ``(width, height)`` passed to ``plt.figure``.
    """
    sns.set_style('whitegrid')
    plt.figure(figsize=figsize)
    # Dashed diagonal from (0, 0) to (1, 1): the chance-level reference line.
    plt.plot([0, 1], [0, 1], 'k--')
    # Use a str.format placeholder; a %-style spec inside .format() is left
    # unformatted and the legend would read "AUC = %0.3f".
    plt.plot(fpr, tpr, label='AUC = {:.3f}'.format(auc))
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title(title)
    plt.legend(loc='best')
    plt.show()

## Precision-Recall Curve

In [44]:
def plot_PRC(precision, recall, ap, title='Precision-Recall Curve', figsize=None,
             drawstyle='steps-post'):
    """Plot a precision-recall curve with its average precision in the legend.

    Args:
        precision: Precision values (y coordinates).
        recall: Recall values (x coordinates), same length as ``precision``.
        ap: Average precision, shown with four decimals in the legend.
        title: Axes title.
        figsize: Optional ``(width, height)``; ``None`` keeps matplotlib's default.
        drawstyle: Matplotlib line drawstyle. The default ``'steps-post'`` draws
            the curve as a step function, as scikit-learn's PrecisionRecallDisplay
            does; straight segments between operating points can look overly
            optimistic. Pass ``'default'`` for straight-line segments.
            NOTE(review): the step direction assumes points are ordered as
            ``sklearn.metrics.precision_recall_curve`` returns them (recall
            decreasing) — confirm against callers.
    """
    sns.set_style('whitegrid')

    plt.figure(figsize=figsize)
    plt.plot(recall, precision, lw=2, drawstyle=drawstyle, label='AP = %0.4f' % ap)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title)
    plt.legend(loc="best")

## Feature Importance Bar Plot

In [45]:
def feature_importance_bar(feat_names, feat_importance, figsize=(5, 20), title='Feature Importance Plot'):
    """Draw a horizontal bar chart of feature importances, largest at the top.

    Args:
        feat_names: Feature names (any values; they are rendered with ``str``).
        feat_importance: Importance score for each feature, in the same order
            as ``feat_names``.
        figsize: ``(width, height)`` passed to ``plt.figure``.
        title: Axes title.

    Raises:
        ValueError: If ``feat_names`` and ``feat_importance`` differ in length.
    """
    names = list(feat_names)
    importances = list(feat_importance)
    if len(names) != len(importances):
        raise ValueError(
            'feat_names and feat_importance must have the same length '
            '(got {} and {})'.format(len(names), len(importances)))

    # Ascending sort: barh puts position 0 at the bottom, so the most
    # important feature ends up on top.
    ranked = sorted(zip(names, importances), key=lambda pair: pair[1])
    positions = list(range(len(ranked)))

    plt.figure(figsize=figsize)
    # Use explicit integer positions with the names as tick labels. Passing
    # the names as y-values would treat numeric names as coordinates (losing
    # the sort order) and merge duplicate names onto a single bar.
    plt.barh(positions, [value for _, value in ranked])
    plt.yticks(positions, [str(name) for name, _ in ranked])
    plt.xlabel('Importance')
    plt.ylabel('Feature Name')
    plt.title(title)

## Confusion Matrix

In [46]:
def plot_cm(y_true, y_pred, title='Confusion Matrix', cmap=plt.cm.Blues):
    """Plot a row-normalized confusion matrix as an annotated heatmap.

    Each row is divided by its total, so cell (i, j) is the fraction of samples
    whose true label is ``classes[i]`` that were predicted as ``classes[j]``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, same length as ``y_true``.
        title: Axes title.
        cmap: Matplotlib colormap for the heatmap.
    """
    # Labels present in either array. They are used both for the matrix
    # rows/columns and for the tick labels, so the two always line up.
    classes = unique_labels(y_true, y_pred)
    # Built from the arguments only, with no dependence on notebook globals.
    counts = metrics.confusion_matrix(y_true, y_pred, labels=classes)

    # Normalize each row by its number of true samples. A class that appears
    # only in y_pred has an all-zero row; keep it at 0 instead of 0/0 = NaN.
    row_totals = counts.sum(axis=1, keepdims=True)
    cm = np.divide(counts, row_totals, out=np.zeros(counts.shape, dtype=float),
                   where=row_totals != 0)

    # Scope the 'white' style to this figure instead of changing the global
    # style that later plots would inherit.
    with sns.axes_style('white'):
        fig, ax = plt.subplots()
        im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        ax.figure.colorbar(im, ax=ax)
        # Show every class as a tick on both axes.
        ax.set(xticks=np.arange(cm.shape[1]),
               yticks=np.arange(cm.shape[0]),
               xticklabels=classes, yticklabels=classes,
               title=title,
               ylabel='True label',
               xlabel='Predicted label')

        # Rotate the x tick labels so long class names do not overlap.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

        # Annotate every cell. Use white text on cells above half the maximum
        # so it stays readable on dark colors.
        fmt = '.2f'
        thresh = cm.max() / 2.
        for i, j in np.ndindex(cm.shape):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
        fig.tight_layout()
    