# Evaluator

In [1]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from functools import reduce
from scipy.interpolate import splev, splprep
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.metrics import confusion_matrix, auc
#from matplotlib.patches import ConnectionPatch
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

In [38]:
class Evaluator:
    """Threshold-sweep evaluation of three score vectors against binary labels.

    For every threshold ``T`` a sample is predicted positive when its score is
    strictly greater than ``T``.  From the resulting confusion counts the class
    derives recall/precision/F1, sensitivity/specificity and TPR/FPR, and can
    draw precision-recall, ROC and F1-vs-threshold figures that are written to
    ``../Figures/``.
    """

    # Suffix shared by the keyword names of the score vectors.
    _SUFFIX = '_scores'
    # Directory (relative to the working directory) that receives the PNGs.
    _FIG_DIR = '../Figures'

    def __init__(self, lw=1, **kwgs):
        """Store scores and labels and prepare colour/marker lookup tables.

        Args:
            lw: Line width used for the curves drawn in the ROC inset.
            **kwgs: Must contain ``wam_scores``, ``bn_scores``, ``svm_scores``
                and ``labels``.  Each score object must support
                ``scores > T`` returning something with ``.astype``
                (e.g. a NumPy array).

        Raises:
            KeyError: If one of the required keyword arguments is missing.
        """
        self.lw = lw
        self.score_names = ['wam_scores', 'bn_scores', 'svm_scores']
        self.Multi_scores = {key: kwgs[key] for key in self.score_names}
        self.labels = kwgs['labels']
        # Kept only so existing attribute access keeps working; nothing in
        # this class ever fills it.
        self.Conf_mas = []
        self.colors = ListedColormap(sns.color_palette("husl", 4))
        colors = self.colors.colors
        keys = self.score_names + ['L']
        self.cmap = dict(zip(keys, colors))
        makers = ['d', 'o', '*', ':']
        self.makermap = dict(zip(keys, makers))

    @classmethod
    def _short_name(cls, name):
        """Return *name* upper-cased with a trailing ``_scores`` removed.

        ``str.rstrip('_scores')`` is deliberately not used: it strips a
        character set, not a suffix, and would mangle names such as
        ``'res_scores'``.
        """
        if name.endswith(cls._SUFFIX):
            name = name[:-len(cls._SUFFIX)]
        return name.upper()

    @classmethod
    def _save(cls, fig, stem, suffix):
        """Write *fig* to ``../Figures/<stem>_<suffix>.png`` and close it."""
        os.makedirs(cls._FIG_DIR, exist_ok=True)
        # Save/close the figure that was passed in rather than whatever
        # pyplot currently considers the "current" figure.
        fig.savefig('{}/{}_{}.png'.format(cls._FIG_DIR, stem, suffix),
                    bbox_inches='tight', dpi=fig.dpi, pad_inches=0.2)
        plt.close(fig)

    def PR_Curves(self, fig, ax, suffix, T_range=np.arange(0, 10, 0.5),
                  xlim=[0.0, 1.05], ylim=[0.0, 1.05], legend=True, sub=False):
        """Plot one precision-recall curve per scorer and save the figure.

        Args:
            fig: Figure that owns *ax*; it is saved as ``PRC_<suffix>.png``
                and then closed.
            ax: Axes to draw on.
            suffix: Text inserted into the output file name.
            T_range: Iterable of thresholds to sweep.
            xlim, ylim: Axis limits (the default lists are never mutated).
            legend: Draw the legend when true.
            sub: When true, add an inset zoomed on recall and precision in
                ``[0.9, 1]``.
        """
        conf_tbs = self.Confusion_table(T_range)
        axins = ax.inset_axes([0.3, 0.3, 0.4, 0.4]) if sub else None

        for name, tb in conf_tbs.items():
            table = self.Cal_Rc_Pr_F1S(tb)
            Rc = table['Recall']
            Pr = table['Precision']
            pr_auc = auc(Rc, Pr)
            ax.plot(Rc, Pr,
                    label='PR curve for {} (area = {:.2})'.format(
                        self._short_name(name), pr_auc))
            if axins is not None:
                axins.plot(Rc, Pr)

        if axins is not None:
            axins.set_xlim(0.9, 1)
            axins.set_ylim(0.9, 1)
            mark_inset(ax, axins, loc1=4, loc2=2)

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_title('Precision-Recall curve')
        if legend:
            ax.legend(loc="lower left", fontsize=8)
        self._save(fig, 'PRC', suffix)

    def ROC_Curves(self, fig, ax, suffix, T_range=np.arange(0, 10, 0.5),
                   xlim=[0.0, 1.05], ylim=[0.0, 1.05], legend=True):
        """Plot one ROC curve per scorer (with a zoomed inset) and save it.

        Args:
            fig: Figure that owns *ax*; it is saved as ``ROC_<suffix>.png``
                and then closed.
            ax: Axes to draw on.
            suffix: Text inserted into the output file name.
            T_range: Iterable of thresholds to sweep.
            xlim, ylim: Axis limits (the default lists are never mutated).
            legend: Draw the legend when true.
        """
        conf_tbs = self.Confusion_table(T_range)
        lw = self.lw
        # The inset always zooms on the top-left corner of the ROC plane.
        axins = ax.inset_axes([0.3, 0.3, 0.4, 0.4])

        for name, tb in conf_tbs.items():
            table = self.Cal_Tpr_Fpr(tb)
            Tpr = table['Tpr']
            Fpr = table['Fpr']
            roc_auc = auc(Fpr, Tpr)
            ax.plot(Fpr, Tpr,
                    label='ROC curve for {} (area = {:.2})'.format(
                        self._short_name(name), roc_auc))
            axins.plot(Fpr, Tpr, lw=lw)

        axins.set_xlim(0, 0.1)
        axins.set_ylim(0.9, 1)
        mark_inset(ax, axins, loc1=3, loc2=1)

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel('False Positive Rate (1-Specificity)')
        ax.set_ylabel('True Positive Rate (Sensitivity)')
        ax.set_title('Receiver operating characteristic')
        if legend:
            ax.legend(loc="lower right", fontsize=8)
        self._save(fig, 'ROC', suffix)

    def F1_t_Curves(self, fig, ax, suffix, T_range=np.arange(0, 10, 0.5),
                    xlim=[0.0, 1.05], ylim=[0.0, 1.05]):
        """Draw a grouped bar chart of F1-score per threshold and save it.

        Args:
            fig: Figure that owns *ax*; it is saved as ``F1_t_<suffix>.png``
                and then closed.
            ax: Axes to draw on.
            suffix: Text inserted into the output file name.
            T_range: Iterable of thresholds to sweep.
            xlim: Only thresholds strictly inside ``(xlim[0], xlim[1])`` are
                plotted.
            ylim: Accepted for signature symmetry; currently unused.
        """
        conf_tbs = self.Confusion_table(T_range)

        labels = []
        frames = []
        for name, tb in conf_tbs.items():
            label = self._short_name(name)
            labels.append(label)
            frames.append(
                self.Cal_Rc_Pr_F1S(tb)[['Threshold', 'F1-score']]
                .rename(columns={'F1-score': label}))
        df = reduce(lambda left, right: pd.merge(left, right, on='Threshold'),
                    frames)

        inside = (df['Threshold'] > xlim[0]) & (df['Threshold'] < xlim[1])
        # Copy so the assignment below does not write into a slice of ``df``.
        df = df[inside].copy()
        df['Threshold'] = df['Threshold'].apply(lambda x: '%.1f' % x)

        ax.set_xlabel('Threshold')
        ax.set_ylabel('F1-score')
        ax.set_title('F1-score-Threshold')
        df.plot(y=labels, x='Threshold', kind='bar',
                grid=True, yticks=np.arange(0, 1, 0.1),
                ax=ax, sharex=True, sharey=True)
        self._save(fig, 'F1_t', suffix)

    def Confusion_table(self, T_range):
        """Count TN/FP/FN/TP for every scorer at every threshold.

        Args:
            T_range: Iterable of thresholds; a sample is predicted positive
                when ``score > T``.

        Returns:
            dict mapping each name in ``self.score_names`` to a DataFrame with
            columns ``Threshold``, ``TN``, ``FP``, ``FN`` and ``TP`` (one row
            per threshold).
        """
        y_true = self.labels
        thresholds = list(T_range)
        tables = {}
        for name in self.score_names:
            scores = self.Multi_scores[name]
            # ``int`` instead of the removed ``np.int`` alias.
            rows = [confusion_matrix(y_true, (scores > T).astype(int),
                                     labels=[0, 1]).ravel()
                    for T in thresholds]
            # reshape keeps four columns even when there are no thresholds.
            counts = np.array(rows, dtype=np.int64).reshape(-1, 4)
            table = pd.DataFrame({'Threshold': thresholds})
            # ravel() of a 2x2 confusion matrix yields tn, fp, fn, tp.
            for col, key in enumerate(('TN', 'FP', 'FN', 'TP')):
                table[key] = counts[:, col]
            tables[name] = table
        return tables

    def Cal_Sn_Sp(self, conf_tb):
        """Return a copy of *conf_tb* with Sensitivity and Specificity added.

        A ratio whose denominator is zero is reported as 1.  The input frame
        is not modified.
        """
        conf_tb = conf_tb.copy()
        conf_tb['Sensitivity'] = conf_tb['TP'] / (conf_tb['TP'] + conf_tb['FN'])
        conf_tb['Specificity'] = conf_tb['TN'] / (conf_tb['TN'] + conf_tb['FP'])
        added = ['Sensitivity', 'Specificity']
        # Fill only the columns computed here so earlier metrics are untouched.
        conf_tb[added] = conf_tb[added].fillna(1)
        return conf_tb

    def Cal_Tpr_Fpr(self, conf_tb):
        """Return a copy of *conf_tb* with Tpr and Fpr columns added.

        A ratio whose denominator is zero is reported as 1.  The input frame
        is not modified.
        """
        conf_tb = conf_tb.copy()
        conf_tb['Tpr'] = conf_tb['TP'] / (conf_tb['TP'] + conf_tb['FN'])
        conf_tb['Fpr'] = conf_tb['FP'] / (conf_tb['TN'] + conf_tb['FP'])
        added = ['Tpr', 'Fpr']
        conf_tb[added] = conf_tb[added].fillna(1)
        return conf_tb

    def Cal_Rc_Pr_F1S(self, conf_tb):
        """Return a copy of *conf_tb* with Recall, Precision and F1-score.

        Recall/precision with a zero denominator are reported as 1.  F1 is
        ``2 * P * R / (P + R)``; when both precision and recall are 0 the
        quotient is 0/0 and F1 is reported as 0.  The input frame is not
        modified.
        """
        conf_tb = conf_tb.copy()
        conf_tb['Recall'] = conf_tb['TP'] / (conf_tb['TP'] + conf_tb['FN'])
        conf_tb['Precision'] = conf_tb['TP'] / (conf_tb['TP'] + conf_tb['FP'])
        added = ['Recall', 'Precision']
        conf_tb[added] = conf_tb[added].fillna(1)
        f1 = (2 * conf_tb['Precision'] * conf_tb['Recall'] /
              (conf_tb['Precision'] + conf_tb['Recall']))
        # P == R == 0 means no true positives at all, i.e. F1 = 0 (not NaN,
        # and certainly not 1 as a later blanket fillna(1) would make it).
        conf_tb['F1-score'] = f1.fillna(0)
        return conf_tb

    def Save_figs(self, trange, suffix=''):
        """Create and save the F1, ROC and PR figures.

        Args:
            trange: Thresholds used for the ROC and PR curves.  The F1 chart
                always sweeps ``np.arange(-10, 10, 0.5)``.
            suffix: Text inserted into the three output file names.
        """
        fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=200)
        self.F1_t_Curves(fig, ax, suffix, T_range=np.arange(-10, 10, 0.5),
                         xlim=[-10, 10])
        fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=200)
        self.ROC_Curves(fig, ax, suffix, T_range=trange, xlim=[-0.05, 1.05])
        fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=200)
        self.PR_Curves(fig, ax, suffix, T_range=trange, xlim=[-0.05, 1.05],
                       sub=True)

    def Cal_metrics(self, trange):
        """Return one wide table of all metrics for all scorers.

        Every per-scorer column is prefixed with the scorer's short name and
        a dot (e.g. ``WAM.Rc``); the threshold column is named ``T``.  Values
        are rounded to two decimals before the per-scorer tables are merged
        on ``T``.
        """
        tb = self.Confusion_table(T_range=trange)
        for name in list(tb):
            n = self._short_name(name) + '.'
            names = {'Threshold': 'T', 'Recall': n+'Rc', 'Precision': n+'Pr',
                     'Sensitivity': n+'Sn', 'Specificity': n+'Sp',
                     'F1-score': n+'F1', 'TP': n+'TP', 'TN': n+'TN',
                     'FP': n+'FP', 'FN': n+'FN', 'Tpr': n+'Tpr', 'Fpr': n+'Fpr'}
            table = self.Cal_Rc_Pr_F1S(tb[name])
            table = self.Cal_Sn_Sp(table)
            table = self.Cal_Tpr_Fpr(table)
            tb[name] = table.round(2).rename(columns=names)
        return reduce(lambda x, y: pd.merge(left=x, right=y, on=['T']),
                      tb.values())

# Demo cell: build synthetic scores/labels and tabulate every metric.
import matplotlib.image as pli
plt.ion()
# NOTE(review): 'science'/'ieee'/'no-latex' are not built-in matplotlib styles;
# presumably they come from the SciencePlots package — confirm it is installed.
plt.style.use(['science','ieee','no-latex'])

# Evenly spaced values in [0, 1) with step 0.001; used directly as 'bn' scores.
pred = np.arange(0, 1, 0.001)
# Ground truth: positive (1) wherever the underlying value exceeds 0.4.
y = (np.arange(0,1,0.001) > 0.4).astype(int)
# 'wam' and 'svm' scores are `pred` plus Gaussian noise (std 0.1) centred at
# +0.1 and -0.1 respectively.  The noise size (1000) must match len(pred).
# No random seed is set, so results differ between runs.
# NOTE: the name `eval` shadows the Python builtin of the same name.
eval = Evaluator(wam_scores=pred + np.random.normal(0.1, scale=0.1, size=(1000)), 
                 bn_scores=pred, 
                 svm_scores=pred + np.random.normal(-0.1, scale=0.1, size=(1000)),labels=y)
# Metrics table for thresholds -1.0, -0.9, ..., up to (but excluding) 2.
eval.Cal_metrics(np.arange(-1,2,0.1))

# Demo cell: regenerate synthetic data, save the three figures, reload them.
# The next line is an IPython magic; it is only valid inside IPython/Jupyter.
%matplotlib inline
import matplotlib.image as pli
plt.ion()
# NOTE(review): 'science'/'ieee'/'no-latex' are not built-in matplotlib styles;
# presumably they come from the SciencePlots package — confirm it is installed.
plt.style.use(['science','ieee','no-latex'])

# Same synthetic setup as the metrics cell: `pred` doubles as the 'bn' scores,
# labels are 1 where the value exceeds 0.4, and 'wam'/'svm' add Gaussian noise
# (std 0.1, mean +0.1 / -0.1).  Unseeded, so the noise differs on every run.
pred = np.arange(0, 1, 0.001)
y = (np.arange(0,1,0.001) > 0.4).astype(int)
# NOTE: the name `eval` shadows the Python builtin of the same name.
eval = Evaluator(wam_scores=pred + np.random.normal(0.1, scale=0.1, size=(1000)), 
                 bn_scores=pred, 
                 svm_scores=pred + np.random.normal(-0.1, scale=0.1, size=(1000)),labels=y)
# Writes ../Figures/F1_t_1.png, ../Figures/ROC_1.png and ../Figures/PRC_1.png.
eval.Save_figs(np.arange(-1,2,0.1), suffix='1')
# Read the saved PNGs back as image arrays: [0] = F1, [1] = ROC, [2] = PR.
figs = [pli.imread(i) for i in 
        ['../Figures/F1_t_1.png', '../Figures/ROC_1.png', '../Figures/PRC_1.png']]

def no_box(ax):
    """Remove ticks, the axis decorations and all spines from *ax*.

    Used to show an image in an axes without any surrounding frame.
    """
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    ax.axis('off')
    # Only the spine objects matter here ('left', 'right', 'bottom', 'top').
    for spine in ax.spines.values():
        spine.set_visible(False)

# Assemble the three saved PNGs into one frameless composite figure:
# PR curve top-left, ROC curve top-right, F1 bar chart across the bottom row.
plt.figure(figsize=(12,12), dpi=200)
plt.subplots_adjust(left=0.0,bottom=0.0,top=1,right=1)
# Top-left cell of a 2x2 grid: precision-recall image (figs[2]).
ax1 = plt.subplot(2,2,1)
no_box(ax1)
plt.imshow(figs[2])
# Top-right cell of the 2x2 grid: ROC image (figs[1]).
ax2 = plt.subplot(2,2,2)
no_box(ax2)
plt.imshow(figs[1])
# Second row of a 2x1 grid, i.e. the full-width bottom half: F1 image (figs[0]).
ax3 = plt.subplot(2,1,2)
no_box(ax3)
plt.imshow(figs[0])
#plt.subplots_adjust(left=0.0,bottom=0.0,top=1,right=1)
# tight_layout runs after the subplots_adjust call above and re-computes the
# subplot spacing.
plt.tight_layout() 
#fig.subplots_adjust(left=0.3, right=0.7, bottom=0.3, top=0.7)
#fig.show()