In [1]:
import pickle
import numpy as np
np.set_printoptions(precision=4)

import pandas as pd
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import confusion_matrix

In [2]:
sub_to_pred = pickle.load(open( "sub_to_pred.pkl", "rb" ))
sub_to_true = pickle.load(open( "sub_to_true.pkl", "rb" ))

In [3]:
# pick a threshold, with real quantity number, based on fpr & tpr lists

def pick_threshold(fpr, tpr, thresholds, MinMax_transformer):
    
    # from auroc, get the point with biggest area
    #idx_max = np.argmax(((1-fpr) * tpr))
    
    auc_value_list = (1-fpr) * tpr
    max_auc_value = np.amax(auc_value_list)
    
    idx_max = get_allindex(auc_value_list, max_auc_value)
                        
    # inverse back to real quantity
    picked_threshold = MinMax_transformer.inverse_transform(thresholds[idx_max].reshape(-1, 1))
    
    return picked_threshold

In [4]:
#get all indexes from a list for an element

def get_allindex(Alist, element):

    return [i for i, j in enumerate(Alist) if j == element]

In [5]:
def print_confusion_matrix(sub_to_true_ternary, sub_to_pred_ternary, all_sub_list):

    # consolidated results, on top of each sub
    all_CHO_true_list = []
    all_CHO_pred_list = []

    all_pro_true_list = []
    all_pro_pred_list = []

    all_fat_true_list = []
    all_fat_pred_list = []


    for sub in all_sub_list:
        print(sub)
        print()

        print('CHO:')
        print(confusion_matrix(sub_to_true_ternary[sub][:,0].reshape(-1,), sub_to_pred_ternary[sub][:,0].reshape(-1,), labels=[0,1,2]))
        all_CHO_true_list.extend(sub_to_true_ternary[sub][:,0].reshape(-1,).tolist())
        all_CHO_pred_list.extend(sub_to_pred_ternary[sub][:,0].reshape(-1,).tolist())


        print('pro:')
        print(confusion_matrix(sub_to_true_ternary[sub][:,1].reshape(-1,), sub_to_pred_ternary[sub][:,1].reshape(-1,), labels=[0,1,2]))
        all_pro_true_list.extend(sub_to_true_ternary[sub][:,1].reshape(-1,).tolist())
        all_pro_pred_list.extend(sub_to_pred_ternary[sub][:,1].reshape(-1,).tolist())


        print('fat:')
        print(confusion_matrix(sub_to_true_ternary[sub][:,2].reshape(-1,), sub_to_pred_ternary[sub][:,2].reshape(-1,), labels=[0,1,2]))
        all_fat_true_list.extend(sub_to_true_ternary[sub][:,2].reshape(-1,).tolist())
        all_fat_pred_list.extend(sub_to_pred_ternary[sub][:,2].reshape(-1,).tolist())


        print('='*100)
       
    print('CHO:')
    print(confusion_matrix(all_CHO_true_list,all_CHO_pred_list))
    
    print('pro:')
    print(confusion_matrix(all_pro_true_list,all_pro_pred_list))
    
    print('fat:')
    print(confusion_matrix(all_fat_true_list,all_fat_pred_list))

In [6]:
# sub_to_true => sub_to_true_ternary

def convert_true_NumberToClass(sub_to_true, all_sub_list):
    
    sub_to_true_ternary = {}
    
    for sub in all_sub_list:

        if sub not in sub_to_true_ternary:
            
            sub_to_true_ternary[sub] = []
        
        # CHO
        sub_to_true[sub][sub_to_true[sub] == 52.25] = 0
        sub_to_true[sub][sub_to_true[sub] == 94.75] = 1
        sub_to_true[sub][sub_to_true[sub] == 179.75] = 2
        
        
        # pro
        sub_to_true[sub][sub_to_true[sub] == 15] = 0
        sub_to_true[sub][sub_to_true[sub] == 30] = 1
        sub_to_true[sub][sub_to_true[sub] == 60] = 2
        
        
        # fat
        sub_to_true[sub][sub_to_true[sub] == 13] = 0
        sub_to_true[sub][sub_to_true[sub] == 26] = 1
        sub_to_true[sub][sub_to_true[sub] == 52] = 2
    
        sub_to_true[sub] = sub_to_true[sub].astype(int)
        
    return sub_to_true

In [7]:
# get each sub: fpr, tpr and the picked thresholds lists

# only show figure when told so
plt.ioff()


def AUROCplot_fprtpr_pickedThresholds(sub_to_true, sub_to_pred, all_sub_list, all_comp_list, binary_list, className_number_dict):
    
    # to have this dict to plot AUROC (x-axis: fpr, y-axis: tpr)
    sub_comp_3binariesClass_fprtpr_dict = {} 
    
    # thresholds for each sub, comp, binary comparisons (two binary classes)
    sub_comp_3binariesClass_thresholds_dict = {}
    
    for sub in all_sub_list:

        if sub not in sub_comp_3binariesClass_fprtpr_dict:
            sub_comp_3binariesClass_fprtpr_dict[sub] = {}
        if sub not in sub_comp_3binariesClass_thresholds_dict:
            sub_comp_3binariesClass_thresholds_dict[sub] = {}
        
        
        for comp in all_comp_list:

            if comp not in sub_comp_3binariesClass_fprtpr_dict[sub]:
                sub_comp_3binariesClass_fprtpr_dict[sub][comp] = {}
            if comp not in sub_comp_3binariesClass_thresholds_dict[sub]:
                sub_comp_3binariesClass_thresholds_dict[sub][comp] = {}
                
                
            #each sub and comp has a new figure
            fig = plt.figure()
            plt.title(sub + ' ' + comp)


            for binary_comparison in binary_list:

                if binary_comparison not in sub_comp_3binariesClass_fprtpr_dict[sub][comp]:
                    sub_comp_3binariesClass_fprtpr_dict[sub][comp][binary_comparison] = {'fpr':[], 'tpr':[], 'thresholds':[]}
                if binary_comparison not in sub_comp_3binariesClass_thresholds_dict[sub][comp]:
                    sub_comp_3binariesClass_thresholds_dict[sub][comp][binary_comparison] = []
                
                
                
                if comp == 'CHO':
                    comp_position = 0

                elif comp == 'pro':
                    comp_position = 1

                elif comp == 'fat':
                    comp_position = 2

                pred_list = np.asarray(sub_to_pred[sub])[:,comp_position]
                true_list = np.asarray(sub_to_true[sub])[:,comp_position]
                
                
                #### get ready proba and labels for sklearn.roc_curve ############
                ##################################################################
                # min-max pred_list w.r.t. that comp column only
                MinMax_transformer = MinMaxScaler()
                pred_MinMax_list = MinMax_transformer.fit_transform(pred_list.reshape(-1, 1))

                true_df = pd.get_dummies(true_list)

                # two real quantities of that comp
                Binaryclass_volume_list = className_number_dict[comp][binary_comparison]

                # indexes list for that two quantities in true_list
                idx_list = list(set(get_allindex(true_list, Binaryclass_volume_list[0]) + get_allindex(true_list, Binaryclass_volume_list[1])))

                # true list
                #get the label list from bigger class's list
                true_list = true_df[np.amax(np.asarray(Binaryclass_volume_list))].values[idx_list]
                 
                # proba list
                #get list from min-max pred_list
                prob_list = pred_MinMax_list[idx_list]
                ##################################################################
                
                
                
                fpr, tpr, thresholds = roc_curve(true_list, prob_list)
                picked_threshold = pick_threshold(fpr, tpr, thresholds, MinMax_transformer)

                
                #****** deal with multiple optimal thresholds **************
                if len(picked_threshold) > 1:
                    
                    #only deal with 'lowVSmid' and 'midVShigh' for now
                    if binary_comparison == 'lowVSmid':
                        picked_threshold = np.asarray([np.amin(picked_threshold)])
                    elif binary_comparison == 'midVShigh':
                        picked_threshold = np.asarray([np.amax(picked_threshold)])
                        
                    # 'lowVShigh' random select second from sorted list
                    elif binary_comparison == 'lowVShigh':
                           
                        #picked_threshold = np.asarray([np.random.choice(picked_threshold.reshape(-1,), 1)])
                        #picked_threshold = np.amin(picked_threshold)
                        #picked_threshold = np.amax(picked_threshold)
                        picked_threshold = np.asarray([np.sort(picked_threshold)[1]])
                picked_threshold = picked_threshold.reshape(1,1)
                #************************************************************        
                
                
                
                
                
                    
                    
                sub_comp_3binariesClass_fprtpr_dict[sub][comp][binary_comparison]['fpr'] = fpr
                sub_comp_3binariesClass_fprtpr_dict[sub][comp][binary_comparison]['tpr'] = tpr
                sub_comp_3binariesClass_fprtpr_dict[sub][comp][binary_comparison]['thresholds'] = thresholds

                sub_comp_3binariesClass_thresholds_dict[sub][comp][binary_comparison].append(picked_threshold)
                
                
                plt.plot([0, 1], [0, 1], 'k--', lw=2)
                plt.plot(fpr, tpr, 'o-', label=str(binary_comparison.split('VS')), alpha=0.3)
                plt.legend()

                
             
            #choose from showing the figures and saving the figures
            #plt.show()
            fig.savefig('AUROC_figures/'+sub + ' ' + comp + '.png')

        
        
        
    
                                    
                
    return (sub_comp_3binariesClass_fprtpr_dict, sub_comp_3binariesClass_thresholds_dict)

# (AUROC plot)   (fpr tpr)   (picked Thresholds)

In [8]:
all_sub_list = ['38A', '38B', '38C', '38D', '38E', '38F', '38H']
#all_sub_list= []

all_comp_list = ['CHO', 'pro', 'fat']

# all binary comparisons' names
binary_list = ['lowVSmid', 'midVShigh', 'lowVShigh']

className_number_dict = {'CHO':{'lowVSmid':[52.25, 94.75], 'midVShigh':[94.75, 179.75], 'lowVShigh':[52.25, 179.75]},
                         'pro':{'lowVSmid':[15, 30], 'midVShigh':[30, 60], 'lowVShigh':[15, 60]},
                         'fat':{'lowVSmid':[13, 26], 'midVShigh':[26, 52], 'lowVShigh':[13, 26]}}






sub_comp_3binariesClass_fprtpr_dict, sub_comp_3binariesClass_thresholds_dict = AUROCplot_fprtpr_pickedThresholds(sub_to_true, 
                                                                                                                  sub_to_pred, 
                                                                                                                  all_sub_list, 
                                                                                                                  all_comp_list, 
                                                                                                                  binary_list, 
                                                                                                                  className_number_dict)




# thresholding regression for classification

In [9]:
# use threshold to classify
# low: 0, mid: 1, high: 2
# return list of tenary classificatin results
# sub_to_pred + thresholds => sub_to_pred_ternary

def thresholding_rgrssn_for_clssftn(all_sub_list, all_comp_list, binary_list, sub_to_pred, sub_to_true_ternary, 
                                    sub_comp_3binariesClass_thresholds_dict):
    
    
    #strict on low and high class, loose on mid class
    
    sub_to_pred_ternary = {}
    
    for sub in all_sub_list:
        
        if sub not in sub_to_pred_ternary:
            
            sub_to_pred_ternary[sub] = []
        
        for comp in all_comp_list:
            
            
            if comp == 'CHO':
                    comp_position = 0

            elif comp == 'pro':
                comp_position = 1

            elif comp == 'fat':
                comp_position = 2
                    
                    
            pred_list = np.asarray(sub_to_pred[sub])[:,comp_position]
            true_list = sub_to_true[sub][:, comp_position]
                
                
            # 3 thresholds

            lowVSmid_threshold = sub_comp_3binariesClass_thresholds_dict[sub][comp]['lowVSmid'][0][0]
            midVShigh_threshold = sub_comp_3binariesClass_thresholds_dict[sub][comp]['midVShigh'][0][0]
            lowVShigh_threshold = sub_comp_3binariesClass_thresholds_dict[sub][comp]['lowVShigh'][0][0]
           
            
            three_thresholds = [lowVSmid_threshold.tolist(), midVShigh_threshold.tolist(), lowVShigh_threshold.tolist()]
            three_thresholds.sort()
            
            
            
            
            
            pred_list_ternary = []
            
            if lowVSmid_threshold <= lowVShigh_threshold <= midVShigh_threshold:
            
                for pred_item in pred_list:

                    if pred_item < lowVSmid_threshold:
                        
                        pred_item_ternary = 0
                    
                    elif pred_item > midVShigh_threshold:
                        
                        pred_item_ternary = 2
                        
                    else:
                        
                        pred_item_ternary = 1
                
                    pred_list_ternary.append(pred_item_ternary)
                
                sub_to_pred_ternary[sub].append(np.asarray(pred_list_ternary))
           
        
            elif lowVSmid_threshold <= midVShigh_threshold:
                
                for pred_item in pred_list:
                    
                    #if pred_item < lowVSmid_threshold:
                    if pred_item < three_thresholds[0]:
                        
                        pred_item_ternary = 0
                    
                    #elif pred_item > midVShigh_threshold:
                    elif pred_item >= three_thresholds[2]:
                        
                        pred_item_ternary = 2
                        
                    else:
                        
                        pred_item_ternary = 1
                
                    pred_list_ternary.append(pred_item_ternary)
                
                sub_to_pred_ternary[sub].append(np.asarray(pred_list_ternary))
            
                print()
                print('*'*50)
                print(sub, comp, sub_comp_3binariesClass_thresholds_dict[sub][comp])
                print('*'*50)
                print()
                
            # worse case, just let every below low threshold as low
            #anything else is mid
            else:
                 
                for pred_item in pred_list:
                    
                    #if pred_item < lowVSmid_threshold:
                    if pred_item > three_thresholds[2]:
                        
                        pred_item_ternary = 2
                    
                    #elif pred_item > lowVSmid_threshold:
                    elif pred_item <= three_thresholds[2]:
                    
                        pred_item_ternary = 1

                    pred_list_ternary.append(pred_item_ternary)
                    
                sub_to_pred_ternary[sub].append(np.asarray(pred_list_ternary))
                
                print()
                print('*'*50)
                print(sub, comp, sub_comp_3binariesClass_thresholds_dict[sub][comp])
                print('*'*50)
                print()
           
        
            print(sub, comp, ' three thresholds ', [lowVSmid_threshold.tolist(), midVShigh_threshold.tolist(), lowVShigh_threshold.tolist()])
            print('pred_list: ', pred_list)
            print('pred_list_ternary: ', pred_list_ternary)
            print('true_list:         ', true_list.tolist())
            
                   
            
        sub_to_pred_ternary[sub] = np.asarray(sub_to_pred_ternary[sub]).T
        
        
            
    return sub_to_pred_ternary

In [10]:
sub_to_true_ternary = convert_true_NumberToClass(sub_to_true, all_sub_list)

In [11]:
sub_to_pred_ternary = thresholding_rgrssn_for_clssftn(all_sub_list, all_comp_list, binary_list,sub_to_pred, sub_to_true_ternary, sub_comp_3binariesClass_thresholds_dict)

38A CHO  three thresholds  [[43.44513702392578], [142.1768341064453], [142.1768341064453]]
pred_list:  [ 21.047   65.3364 163.8477  24.5813 142.1768  43.4451  71.0525  57.819 ]
pred_list_ternary:  [0, 1, 2, 0, 1, 1, 1, 1]
true_list:          [0, 1, 2, 0, 2, 1, 1, 1]
38A pro  three thresholds  [[18.184051513671875], [28.435108184814453], [111.61946105957031]]
pred_list:  [ 16.8251  48.7387 111.6195  18.1841  94.767   28.4351  46.0694  37.2553]
pred_list_ternary:  [0, 1, 2, 1, 1, 1, 1, 1]
true_list:          [0, 1, 2, 1, 1, 2, 0, 1]
38A fat  three thresholds  [[52.62800979614258], [52.62800979614258], [183.19920349121094]]
pred_list:  [ 16.8542  52.628  183.1992  18.1627 138.7227  41.3168  57.6916  44.2187]
pred_list_ternary:  [0, 1, 1, 0, 1, 0, 1, 0]
true_list:          [0, 1, 2, 1, 1, 1, 1, 0]
38B CHO  three thresholds  [[29.402202606201172], [108.24459838867188], [108.24459838867188]]
pred_list:  [  0.4729 108.2446  12.7981  35.4536  54.5875  41.5431  29.4022]
pred_list_ternary:  [0, 

In [12]:
print_confusion_matrix(sub_to_true_ternary, sub_to_pred_ternary, all_sub_list)

38A

CHO:
[[2 0 0]
 [0 4 0]
 [0 1 1]]
pro:
[[1 1 0]
 [0 4 0]
 [0 1 1]]
fat:
[[2 0 0]
 [2 3 0]
 [0 1 0]]
38B

CHO:
[[1 0 0]
 [1 4 0]
 [0 1 0]]
pro:
[[0 1 0]
 [0 4 0]
 [0 1 1]]
fat:
[[1 0 0]
 [2 2 0]
 [1 1 0]]
38C

CHO:
[[2 0 0]
 [1 3 0]
 [0 1 1]]
pro:
[[2 0 0]
 [1 2 1]
 [1 1 0]]
fat:
[[1 1 0]
 [2 1 1]
 [0 1 1]]
38D

CHO:
[[2 0 0]
 [0 4 1]
 [0 1 1]]
pro:
[[1 0 1]
 [0 4 1]
 [0 1 1]]
fat:
[[2 0 0]
 [1 3 1]
 [0 1 1]]
38E

CHO:
[[2 0 0]
 [0 5 0]
 [0 1 1]]
pro:
[[1 1 0]
 [0 5 0]
 [0 2 0]]
fat:
[[2 0 0]
 [2 2 1]
 [0 0 2]]
38F

CHO:
[[1 1 0]
 [0 5 0]
 [0 1 0]]
pro:
[[1 1 0]
 [0 5 0]
 [0 1 0]]
fat:
[[0 2 0]
 [0 2 3]
 [0 1 0]]
38H

CHO:
[[1 0 0]
 [0 4 0]
 [0 0 2]]
pro:
[[1 0 0]
 [0 3 1]
 [0 2 0]]
fat:
[[2 0 0]
 [0 2 1]
 [0 2 0]]
CHO:
[[11  1  0]
 [ 2 29  1]
 [ 0  6  6]]
pro:
[[ 7  4  1]
 [ 1 27  3]
 [ 1  9  3]]
fat:
[[10  3  0]
 [ 9 15  7]
 [ 1  7  4]]
