### A notebook containing the majority of the analysis code, used not only to investigate model outputs but also to generate the images found in the report.

In [None]:
import matplotlib.pyplot as plt
import mglearn.tools
import numpy as np
import scipy.stats as stats
import sklearn.metrics as skm

from SERA import SERA

In [None]:
def metrics(file):
    """Print range and correlation statistics for one PointVS output file.

    Each line of ``file`` is read as fixed-width text: characters 0-5 hold the
    true value and characters 8-13 hold the predicted value. Four lines are
    printed: the spread (max - min) of the predictions, the spread of the true
    values, the Pearson correlation coefficient, and the full result object
    returned by ``scipy.stats.spearmanr``. Nothing is returned.
    """
    with open(file) as handle:
        pairs = [
            (float(row[0:6].strip()), float(row[8:14].strip()))
            for row in handle.readlines()
        ]

    observed = [pair[0] for pair in pairs]
    predicted = [pair[1] for pair in pairs]

    # np.ptp is max - min, i.e. the spread of each series.
    print ('pred_range:', np.ptp(predicted))
    print ('true_range:', np.ptp(observed))
    print ('PCC:', np.corrcoef(observed, predicted)[0, 1])
    print ('SPCC', stats.spearmanr(observed, predicted))

In [None]:
def distribution(file, bins=None, ylim=(0, 260), xlim=(-3, 15.4)):
    '''Plot the true and predicted value distributions from a PointVS output file.

    True values are drawn as a red histogram and predicted values as a blue
    histogram on the same bin edges.

    Args:
        file: path to a PointVS output file; each line holds the true value in
            characters 0-5 and the predicted value in characters 8-13.
        bins: histogram bin edges (anything ``Axes.hist`` accepts). Defaults to
            0.3-wide edges from -3 up to, but excluding, 15.0, which is the
            setting previously hard-coded here. Pass different edges for each
            report figure instead of editing this function.
        ylim: y-axis limits ``(low, high)``; ``None`` lets matplotlib autoscale.
        xlim: x-axis limits ``(low, high)``; ``None`` lets matplotlib autoscale.

    Returns:
        ``(fig, ax)`` so the figure can be further customised or saved.
    '''
    true, pred = [], []
    with open(file) as f:
        for line in f:
            true.append(float(line[0:6].strip()))
            pred.append(float(line[8:14].strip()))

    if bins is None:
        bins = np.arange(-3, 15.0, 0.3)

    fig, ax1 = plt.subplots()
    ax1.hist(true, bins, color='red', alpha=0.8)
    ax1.hist(pred, bins, color='blue', alpha=0.5)
    # Limits are applied after both histograms so neither call re-autoscales them.
    if ylim is not None:
        ax1.set_ylim(*ylim)
    if xlim is not None:
        ax1.set_xlim(*xlim)
    ax1.set_xlabel('pKd')
    ax1.set_ylabel('Frequency')
    return fig, ax1

In [None]:
def dist_et_interpol(file, control, bins=None, ylim=(0, 260), xlim=(-3, 15.4)):
    '''Plot true/predicted distributions together with a relevance curve.

    The true (red) and predicted (blue) values from a PointVS output file are
    drawn as histograms on the left axis. The relevance curve produced by
    ``SERA.interpolator`` from the supplied control points, and the control
    points themselves, are drawn on a twin right axis.

    Args:
        file: path to a PointVS output file; each line holds the true value in
            characters 0-5 and the predicted value in characters 8-13.
        control: control points forwarded unchanged to ``SERA.interpolator``.
        bins: histogram bin edges (anything ``Axes.hist`` accepts). Defaults to
            0.3-wide edges from -3 up to, but excluding, 15.0, which is the
            setting previously hard-coded here.
        ylim: left (frequency) axis limits; ``None`` lets matplotlib autoscale.
        xlim: x-axis limits; ``None`` lets matplotlib autoscale.

    Returns:
        ``(fig, ax1, ax2)``: the figure, the histogram axis and the relevance axis.
    '''
    true, pred = [], []
    with open(file) as f:
        for line in f:
            true.append(float(line[0:6].strip()))
            pred.append(float(line[8:14].strip()))

    # NOTE(review): SERA is defined outside this notebook. The meaning of the
    # tuples unpacked below is inferred only from how they are used here:
    # (x, y) is plotted as a curve and (control_set, relevance) as points.
    bounds = SERA._relevance_interval(true)
    y_true = SERA.filter_outliers(true, true, bounds)[0]
    x, y, control_set, relevance = SERA.interpolator(y_true, bounds, control)[1:]

    if bins is None:
        bins = np.arange(-3, 15.0, 0.3)

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.hist(true, bins, color='red', alpha=0.8)
    ax1.hist(pred, bins, color='blue', alpha=0.5)
    ax2.plot(x, y)
    ax2.plot(control_set, relevance, "o")
    ax1.set_xlabel('pKd')
    ax1.set_ylabel('Frequency')
    ax2.set_ylabel('Relevance')
    ax2.set_ylim(-0.050, 1.05)
    if ylim is not None:
        ax1.set_ylim(*ylim)
    if xlim is not None:
        ax1.set_xlim(*xlim)
    return fig, ax1, ax2

In [None]:
def PCC_comparison(file_dict, start_epoch=2, ylim=(0.2, 0.5), xlim=(2, 10)):
    '''
    Plot the Pearson correlation coefficient of each epoch's output file.

    ``file_dict`` maps model names (used as legend labels) to a list of
    PointVS output files, one per epoch, in epoch order.

    e.g.) PCC_comparison({"SERA_1_0_1": ["epoch_2_output.txt", "epoch_3_output.txt"]})

    Args:
        file_dict: model name -> list of per-epoch PointVS output file paths.
        start_epoch: x value at which the first file of each list is plotted.
            Later files go at consecutive integers. Defaults to 2, the value
            previously hard-coded here. Pass ``start_epoch=1`` when the lists
            begin with epoch 1.
        ylim: y-axis limits; ``None`` autoscales. Scores outside the default
            (0.2, 0.5) range would otherwise be clipped from view.
        xlim: x-axis limits; ``None`` autoscales.

    Returns:
        ``(fig, ax)`` of the plotted figure.
    '''
    pcc_dict = {}
    for name, files in file_dict.items():
        pccs = []
        for path in files:
            true, pred = [], []
            with open(path) as f:
                for line in f:
                    true.append(float(line[0:6].strip()))
                    pred.append(float(line[8:14].strip()))
            pccs.append(np.corrcoef(true, pred)[0][1])
        pcc_dict[name] = pccs

    fig, ax1 = plt.subplots()
    for name, pccs in pcc_dict.items():
        ax1.plot(range(start_epoch, start_epoch + len(pccs)), pccs, label=name)
    ax1.legend()
    if ylim is not None:
        ax1.set_ylim(*ylim)
    if xlim is not None:
        ax1.set_xlim(*xlim)
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('Pearson Correlation Coefficient')
    ax1.grid(axis='y', alpha=0.5)
    plt.show()
    return fig, ax1


In [None]:
def rho_comparison(file_dict, start_epoch=2, ylim=(0.2, 0.5), xlim=(2, 10)):
    '''
    Plot the Spearman rank correlation coefficient of each epoch's output file.

    ``file_dict`` maps model names (used as legend labels) to a list of
    PointVS output files, one per epoch, in epoch order.

    e.g.) rho_comparison({"SERA_1_0_1": ["epoch_2_output.txt", "epoch_3_output.txt"]})

    Args:
        file_dict: model name -> list of per-epoch PointVS output file paths.
        start_epoch: x value at which the first file of each list is plotted.
            Later files go at consecutive integers. Defaults to 2, the value
            previously hard-coded here. Pass ``start_epoch=1`` when the lists
            begin with epoch 1.
        ylim: y-axis limits; ``None`` autoscales. Scores outside the default
            (0.2, 0.5) range would otherwise be clipped from view.
        xlim: x-axis limits; ``None`` autoscales.

    Returns:
        ``(fig, ax)`` of the plotted figure.
    '''
    rho_dict = {}
    for name, files in file_dict.items():
        rhos = []
        for path in files:
            true, pred = [], []
            with open(path) as f:
                for line in f:
                    true.append(float(line[0:6].strip()))
                    pred.append(float(line[8:14].strip()))
            rhos.append(stats.spearmanr(true, pred)[0])
        rho_dict[name] = rhos

    fig, ax1 = plt.subplots()
    for name, rhos in rho_dict.items():
        ax1.plot(range(start_epoch, start_epoch + len(rhos)), rhos, label=name)
    ax1.legend()
    if ylim is not None:
        ax1.set_ylim(*ylim)
    if xlim is not None:
        ax1.set_xlim(*xlim)
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('Spearman Rank Correlation Coefficient')
    ax1.grid(axis='y', alpha=0.5)
    plt.show()
    return fig, ax1

In [None]:
def active_metrics(file, thresh):
    '''
    Calculate binary-classification metrics at a user-defined threshold.

    Values ``>= thresh`` belong to the 'positive' class and values below it to
    the 'negative' class. The rule is applied to both the true and predicted
    values of a PointVS output file (true value in characters 0-5, predicted
    value in characters 8-13 of each line). Pairs involving NaN match none of
    the four cases and are ignored.

    Returns:
        dict with keys 'accuracy', 'precision', 'sensitivity', 'specificity'
        and 'FPR'. A metric whose denominator is zero (e.g. precision when
        nothing is predicted positive) is ``float('nan')`` rather than raising
        ZeroDivisionError; specificity is then NaN as well if FPR is.
    '''
    true, pred = [], []
    with open(file) as f:
        for line in f:
            true.append(float(line[0:6].strip()))
            pred.append(float(line[8:14].strip()))

    tp = tn = fn = fp = 0
    for t, p in zip(true, pred):
        if t >= thresh and p >= thresh:
            tp += 1
        elif t < thresh and p < thresh:
            tn += 1
        elif t >= thresh and p < thresh:
            fn += 1
        elif t < thresh and p >= thresh:
            fp += 1

    def _ratio(numerator, denominator):
        # Undefined metrics (empty denominator) are reported as NaN.
        return numerator / denominator if denominator else float('nan')

    accuracy = _ratio(tp + tn, tp + tn + fn + fp)
    precision = _ratio(tp, tp + fp)
    sensitivity = _ratio(tp, tp + fn)
    fpr = _ratio(fp, fp + tn)
    specificity = 1 - fpr

    return {'accuracy': accuracy, 'precision': precision, 'sensitivity': sensitivity, 'specificity': specificity, 'FPR': fpr}


In [None]:

class multiclass_metrics:
    '''Bin PointVS outputs into multiclass-style labels and score them.

    The true and predicted values of a PointVS output file are each assigned
    to one of the unit-width, half-open bins ``[lower, upper)`` defined in
    ``__init__`` (together covering -4 to 16). From these labels you can
    generate a normalised confusion matrix, or the Spearman rank correlation
    coefficient of the correctly binned values in each bin.

    Any (true, predicted) pair in which either value lies outside every bin is
    dropped with a warning, so the true and predicted labels stay aligned.

    Bin metrics e.g.) multiclass_metrics('output_file.txt').bin_metrics()
    Confusion matrix e.g.) multiclass_metrics('output_file.txt').generate_confusion()
    '''

    def __init__(self, file):

        # label -> (lower, upper); a value belongs to a bin when lower <= value < upper.
        self.bins = {
            "-(4-3)": (-4, -3),
            "-(3-2)": (-3, -2),
            "-(2-1)": (-2, -1),
            "-(1, 0)": (-1, 0),
            "0-1": (0, 1),
            "1-2": (1, 2),
            "2-3": (2, 3),
            "3-4": (3, 4),
            "4-5": (4, 5),
            "5-6": (5, 6),
            "6-7": (6, 7),
            "7-8": (7, 8),
            "8-9": (8, 9),
            "9-10": (9, 10),
            "10-11": (10, 11),
            "11-12": (11, 12),
            "12-13": (12, 13),
            "13-14": (13, 14),
            "14-15": (14, 15),
            "15-16": (15, 16),
        }

        # Labels in definition order; used for sorting and as confusion-matrix axes.
        self.labels = list(self.bins)

        # Raw values are read first; populate_bins() then replaces them with labels.
        self.true, self.pred = self._extract_vals(file)

        self.true, self.pred, self.map_bins_true, self.map_bins_pred = self.populate_bins()

    def _label_for(self, value):
        '''Return the label of the bin containing ``value``, or None if no bin does.'''
        for label, (lower, upper) in self.bins.items():
            if lower <= value < upper:
                return label
        return None

    def bin_metrics(self):
        '''Calculate the Spearman rank correlation coefficient for each label.

        Only pairs whose true and predicted values fall in the same bin
        contribute to that bin. Returns a dict keyed by label in definition
        order. Bins with fewer than two such pairs get ``float('nan')``, since
        a rank correlation is undefined for them.
        '''
        matched = {label: ([], []) for label in self.bins}
        for true_entry, pred_entry in zip(self.map_bins_true, self.map_bins_pred):
            true_label, true_value = next(iter(true_entry.items()))
            pred_label, pred_value = next(iter(pred_entry.items()))
            if true_label == pred_label:
                matched[true_label][0].append(true_value)
                matched[true_label][1].append(pred_value)

        spcorr = {}
        for label, (t, p) in matched.items():
            spcorr[label] = stats.spearmanr(t, p)[0] if len(t) > 1 else float('nan')

        return spcorr

    def populate_bins(self):
        '''Divide the true and predicted values into multiclass labels.

        Returns:
            ``(true_sorted, pred_sorted, bin_map_true, bin_map_pred)``:
            the true and predicted labels as aligned lists, stably sorted by
            the true label's position in ``self.labels``; and aligned lists
            of single-entry ``{label: value}`` dicts in file order.

        A pair is dropped (and a warning issued) if either value lies outside
        every bin. Keeping only the in-range half would shift one list against
        the other and pair later values with the wrong partner.
        '''
        import warnings

        bin_map_true, bin_map_pred = [], []
        true_classes, pred_classes = [], []
        skipped = 0
        for true_value, pred_value in zip(self.true, self.pred):
            true_label = self._label_for(true_value)
            pred_label = self._label_for(pred_value)
            if true_label is None or pred_label is None:
                skipped += 1
                continue
            true_classes.append(true_label)
            pred_classes.append(pred_label)
            bin_map_true.append({true_label: true_value})
            bin_map_pred.append({pred_label: pred_value})

        if skipped:
            warnings.warn(
                f"{skipped} value pair(s) fell outside the defined bins and were dropped",
                stacklevel=2,
            )

        # Stable sort by label position: same order as grouping label by label
        # while preserving file order within each label.
        order = {label: i for i, label in enumerate(self.labels)}
        ordered = sorted(zip(true_classes, pred_classes), key=lambda pair: order[pair[0]])
        true_sorted = [t for t, _ in ordered]
        pred_sorted = [p for _, p in ordered]

        return true_sorted, pred_sorted, bin_map_true, bin_map_pred

    def generate_confusion(self):
        '''Plot a row-normalised confusion matrix heatmap for the defined labels.

        Rows are true labels and columns are predicted labels. Each cell shows
        the percentage of that row's samples (``normalize='true'``, scaled by
        100 for display). The unscaled matrix is returned.
        '''
        matrix = skm.confusion_matrix(self.true, self.pred, labels=self.labels, normalize='true')

        plt.figure(figsize=(15, 15))
        mglearn.tools.heatmap(
            matrix*100,
            xlabel="predicted label",
            ylabel="true label",
            xticklabels=self.labels,
            yticklabels=self.labels,
            cmap=plt.cm.gray_r,
            fmt="%d",
        )
        plt.title("confusion matrix")
        # Flip the y axis, reversing the order in which rows are drawn.
        plt.gca().invert_yaxis()
        return matrix

    @staticmethod
    def _extract_vals(file):
        '''Return (true, pred) value lists from a PointVS output file.

        Each line holds the true value in characters 0-5 and the predicted
        value in characters 8-13.
        '''
        true, pred = [], []
        with open(file) as f:
            for line in f:
                true.append(float(line[0:6].strip()))
                pred.append(float(line[8:14].strip()))

        return true, pred