In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve,ConfusionMatrixDisplay, classification_report
import pandas as pd

In [2]:
##we only need the sample table
# NOTE: pickle.load() can execute arbitrary code -- only point these paths at trusted files.
DATA_PKL = '/home/janaya2/Desktop/ATGC_paper/figures/tumor_classification/data/data.pkl'
RESULTS_DIR = '/home/mlee276/Desktop/TCGA-ML-main/results/'


def _load_pickle(path):
    """Unpickle the file at `path`, closing the handle even if loading fails."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


D, tcga_maf, samples = _load_pickle(DATA_PKL)
del tcga_maf, D

# filtering the NCI-T labels (https://livejohnshopkins-my.sharepoint.com/:x:/r/personal/abaras1_jh_edu/_layouts/15/doc2.aspx?sourcedoc=%7B5f92f0fc-ec6c-40d5-ab17-0d3345f9f2c2%7D&action=edit&activeCell=%27Sheet1%27!B21&wdinitialsession=e072a38f-57c8-4c1f-885b-efaefcc81d35&wdrldsc=2&wdrldc=1&wdrldr=AccessTokenExpiredWarning%2CRefreshingExpiredAccessT)
# NOTE(review): 'Colorectal Adenocarcinoma' is listed twice below. That is harmless
# for isin(), but confirm no other label was meant in its place.
ncit_labels_kept = ['Muscle-Invasive Bladder Carcinoma','Infiltrating Ductal Breast Carcinoma',
                    'Invasive Lobular Breast Carcinoma','Cervical Squamous Cell Carcinoma',
                    'Colorectal Adenocarcinoma','Glioblastoma','Head and Neck Squamous Cell Carcinoma',
                    'Clear Cell Renal Cell Carcinoma','Papillary Renal Cell Carcinoma','Astrocytoma',
                    'Oligoastrocytoma','Oligodendroglioma','Hepatocellular Carcinoma','Lung Adenocarcinoma',
                    'Lung Squamous Cell Carcinoma','Ovarian Serous Adenocarcinoma','Adenocarcinoma, Pancreas',
                    'Paraganglioma','Pheochromocytoma','Prostate Acinar Adenocarcinoma','Colorectal Adenocarcinoma',
                    'Desmoid-Type Fibromatosis','Leiomyosarcoma','Liposarcoma','Malignant Peripheral Nerve Sheath Tumor',
                    'Myxofibrosarcoma','Synovial Sarcoma','Undifferentiated Pleomorphic Sarcoma',
                    'Cutaneous Melanoma','Gastric Adenocarcinoma','Testicular Non-Seminomatous Germ Cell Tumor',
                    'Testicular Seminoma','Thyroid Gland Follicular Carcinoma','Thyroid Gland Papillary Carcinoma',
                    'Endometrial Endometrioid Adenocarcinoma','Endometrial Serous Adenocarcinoma']
# .copy() gives an independent frame, so the relabelling below neither touches
# `samples` nor triggers pandas' SettingWithCopyWarning.
ncit_samples = samples.loc[samples['NCI-T Label'].isin(ncit_labels_kept)].copy()

# Fine-grained NCI-T labels that are merged into a single class each.
PCPG_ncit = ['Paraganglioma','Pheochromocytoma']
SARC_ncit = ['Desmoid-Type Fibromatosis','Leiomyosarcoma','Liposarcoma','Malignant Peripheral Nerve Sheath Tumor',
             'Myxofibrosarcoma','Synovial Sarcoma','Undifferentiated Pleomorphic Sarcoma']
TGCT_ncit = ['Testicular Non-Seminomatous Germ Cell Tumor','Testicular Seminoma']
for merged_label, member_labels in (('PCPG', PCPG_ncit), ('SARC', SARC_ncit), ('TGCT', TGCT_ncit)):
    ncit_samples.loc[ncit_samples['NCI-T Label'].isin(member_labels), 'NCI-T Label'] = merged_label

A = ncit_samples['NCI-T Label'].astype('category')
classes = A.cat.categories.values
##integer values for random forest
classes_onehot = np.eye(len(classes))[A.cat.codes]
y_label = classes_onehot

# Integer class index per sample, plus inverse-frequency sample weights that sum to 1.
y_strat = np.argmax(y_label, axis=-1)
class_counts = dict(zip(*np.unique(y_strat, return_counts=True)))
y_weights = np.array([1 / class_counts[_] for _ in y_strat])
y_weights /= np.sum(y_weights)

##all the stratifications were the same so all test_idx should match up
mil_test_idx, mil_predictions = _load_pickle(RESULTS_DIR + 'mil_contexts_predictions.pkl')
nn_test_idx, nn_predictions = _load_pickle(RESULTS_DIR + 'nn_contexts_predictions.pkl')
test_idx, rf_predictions = _load_pickle(RESULTS_DIR + 'rf_contexts_predictions.pkl')

# Check the assumption above instead of silently trusting it: every model's
# predictions must refer to the same samples in the same order.
flat_test_idx = np.concatenate(test_idx)
for model_name, other_test_idx in (('mil', mil_test_idx), ('nn', nn_test_idx)):
    if not np.array_equal(np.concatenate(other_test_idx), flat_test_idx):
        raise ValueError('test_idx of the %s predictions does not match the rf predictions' % model_name)

rf_predictions = np.vstack(rf_predictions)
correct = y_strat[flat_test_idx]
# one hot correct vals: [1,2,3] - > [[0,1,0,0],[0,0,1,0],[0,0,0,1]]
# The width is len(classes) rather than correct.max()+1, so the matrix keeps one
# column per class even if the highest-numbered class is missing from the test folds.
onehot = np.eye(len(classes))[correct]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_column(loc, value, pi)


In [5]:
# true must be onehot
# pred must be list of accuracies
# classNames must be list of classes
def plot_metrics(true, pred, classNames, notOneHot):
    %matplotlib
    
    f = plt.figure(figsize=(16,14), constrained_layout=True)
    gs = f.add_gridspec(6, 6)
    ax = dict()

    # ROC AUC Plot
    colors = []
    aucs = []
    ax['roc_auc'] = f.add_subplot(gs[0:2, 0:2]) #(2, 2, 1)
    colormap = plt.cm.nipy_spectral
    cycleColors = [colormap(i) for i in np.linspace(0,1,len(classes))]
    ax['roc_auc'].set_prop_cycle('color',cycleColors)
    for i in range(true.shape[1]):
         fpr, tpr, _ = roc_curve(true[:, i], pred[:, i])
         auc = roc_auc_score(true[:, i], pred[:, i])
         aucs.append(auc)
         ax['roc_auc'].plot(fpr, tpr, linewidth=0.5, label='%s (%.3f)' % (classNames[i], auc))
         colors.append(ax["roc_auc"].get_lines()[-1].get_color())
    ax['roc_auc'].set_title('ROC Curve')
    ax['roc_auc'].set_ylabel('True Positives')
    ax['roc_auc'].set_xlabel('False Positives')

    labelList = []
    for i in range(len(aucs)):
        labelList.append('%s (%.3f)' % (classNames[i], aucs[i]))

    # Confusion Matrix Plot
    ax['confusion_matrix'] = f.add_subplot(gs[3:5, 0:2]) #(2, 2, 2)
    conf_mat = confusion_matrix(true.argmax(axis=1), pred.argmax(axis=1))
    disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat) 
    disp.plot(ax=ax['confusion_matrix'])
    disp.im_.colorbar.remove()
    ax['confusion_matrix'].set(yticks=np.arange(len(classNames)), yticklabels=labelList)
    [t.set_color(i) for (i,t) in zip((colors),ax["confusion_matrix"].yaxis.get_ticklabels())]
    ax['confusion_matrix'].set_title('confusion matrix')
    ax['confusion_matrix'].set_ylabel('True Classes')
    ax['confusion_matrix'].yaxis.set_label_position("right")
    ax['confusion_matrix'].set_xlabel('Predicted Classes')
    
    # Precision and Recall per Class Table
    ax['classification_report'] = f.add_subplot(gs[0:2,3:5])
    clf = classification_report(notOneHot, pred.argmax(axis=1),target_names=classNames,output_dict=True)
    sns.heatmap(pd.DataFrame(clf).iloc[:-1, :].T, annot=True, ax=ax['classification_report'])
    
    # Accuracy per Class Table
    ax['accuracies'] = f.add_subplot(gs[3:5,3:5])
    acc_per_class = conf_mat.diagonal()/conf_mat.sum(axis=0)
    mat = [[0]*2 for i in range(len(classNames))]
    k=0
    for i, j in zip(classNames, acc_per_class):
        mat[k][0]=i
        mat[k][1]='{0:.2f}'.format(j)
        k+=1
    column_labels=["Class","Accuracy"]
    ax['accuracies'].axis('tight')
    ax['accuracies'].axis('off')
    ax['accuracies'].table(cellText=mat,colLabels=column_labels,loc="center")

    # Our default 
    plt.subplots_adjust(
    top=0.969,
    bottom=0.0,
    left=0.048,
    right=1.0,
    hspace=0.0,
    wspace=0.0) 

    plt.show()


In [6]:
# Summary figure for the MIL model; pass pred=rf_predictions to inspect the random forest instead.
plot_metrics(true=onehot, pred=mil_predictions, classNames=classes, notOneHot=correct)

Using matplotlib backend: Qt5Agg


  plt.subplots_adjust(
findfont: Font family ['normal'] not found. Falling back to DejaVu Sans.


In [7]:
# true must be onehot
# pred must be list of accuracies
# classNames must be list of classes
def plot_CLFs(mil_pred, nn_pred, rf_pred, classNames, notOneHot):
    %matplotlib
    
    f = plt.figure(figsize=(16,12), constrained_layout=True)
    ax = dict()

    ax['mil'] = f.add_subplot(1, 3, 1)
    clf_mil = classification_report(notOneHot, mil_pred.argmax(axis=1), target_names=classNames, output_dict=True)
    sns.heatmap(pd.DataFrame(clf_mil).iloc[:-1, :].T, annot=True, ax=ax['mil'], cbar=False)
    ax['mil'].set_title('MIL')
    
    ax['nn'] = f.add_subplot(1, 3, 2)
    clf_nn = classification_report(notOneHot, nn_pred.argmax(axis=1), target_names=classNames, output_dict=True)
    sns.heatmap(pd.DataFrame(clf_nn).iloc[:-1, :].T, annot=True, ax=ax['nn'], cbar=False)
    ax['nn'].set_title('NN')
    
    ax['rf'] = f.add_subplot(1, 3, 3)
    clf_rf = classification_report(notOneHot, rf_pred.argmax(axis=1), target_names=classNames, output_dict=True)
    sns.heatmap(pd.DataFrame(clf_rf).iloc[:-1, :].T, annot=True, ax=ax['rf'], cbar=False)
    ax['rf'].set_title('RF')

    # Our default 
    plt.subplots_adjust(
    top=0.949,
    bottom=0.079,
    left=0.058,
    right=0.991,
    hspace=0.803,
    wspace=0.945) 

    plt.show()

# Compare the three models' classification reports on the shared test samples.
plot_CLFs(
    mil_pred=mil_predictions,
    nn_pred=nn_predictions,
    rf_pred=rf_predictions,
    classNames=classes,
    notOneHot=correct,
)

Using matplotlib backend: Qt5Agg


  ax.figure.draw(ax.figure.canvas.get_renderer())
  ax.figure.draw(ax.figure.canvas.get_renderer())
  plt.subplots_adjust(


In [8]:
# true must be onehot
# pred must be list of accuracies
# classNames must be list of classes
def plot_conf_mat(true, mil_pred, nn_pred, rf_pred, classNames, notOneHot):
    %matplotlib
    
    f = plt.figure(figsize=(16,14), constrained_layout=True)
    ax = dict()
    
    font = {'family' : 'normal',
            'size'   : 7}
    plt.rc('font', **font)

    # Confusion Matrix Plot
    ax['confusion_matrix'] = f.add_subplot(1,3,1) #(2, 2, 2)
    conf_mat = confusion_matrix(true.argmax(axis=1), mil_pred.argmax(axis=1), normalize='true')
    conf_mat = np.asarray([[round(j*100) for j in i] for i in conf_mat])
    disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat) 
    disp.plot(ax=ax['confusion_matrix'], cmap=plt.cm.Blues)
    disp.im_.colorbar.remove()
    ax['confusion_matrix'].set(yticks=np.arange(len(classNames)), yticklabels=classNames)
    #[t.set_color(i) for (i,t) in zip((colors),ax["confusion_matrix"].yaxis.get_ticklabels())]
    ax['confusion_matrix'].set_title('MIL (noramlized, percent)')
    ax['confusion_matrix'].set_ylabel('True Classes')
    ax['confusion_matrix'].yaxis.set_label_position("right")
    ax['confusion_matrix'].set_xlabel('Predicted Classes')
    
    ax['confusion_matrix'] = f.add_subplot(1,3,2) #(2, 2, 2)
    conf_mat = confusion_matrix(true.argmax(axis=1), nn_pred.argmax(axis=1), normalize='true')
    conf_mat = np.asarray([[round(j*100) for j in i] for i in conf_mat])
    disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat) 
    disp.plot(ax=ax['confusion_matrix'], cmap=plt.cm.Blues)
    disp.im_.colorbar.remove()
    ax['confusion_matrix'].set(yticks=np.arange(len(classNames)), yticklabels=classNames)
    #[t.set_color(i) for (i,t) in zip((colors),ax["confusion_matrix"].yaxis.get_ticklabels())]
    ax['confusion_matrix'].set_title('NN (noramlized, percent)')
    ax['confusion_matrix'].set_ylabel('True Classes')
    ax['confusion_matrix'].yaxis.set_label_position("right")
    ax['confusion_matrix'].set_xlabel('Predicted Classes')
    
    ax['confusion_matrix'] = f.add_subplot(1,3,3) #(2, 2, 2)
    conf_mat = confusion_matrix(true.argmax(axis=1), rf_pred.argmax(axis=1), normalize='true')
    conf_mat = np.asarray([[round(j*100) for j in i] for i in conf_mat])
    disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat) 
    disp.plot(ax=ax['confusion_matrix'], cmap=plt.cm.Blues)
    disp.im_.colorbar.remove()
    ax['confusion_matrix'].set(yticks=np.arange(len(classNames)), yticklabels=classNames)
    #[t.set_color(i) for (i,t) in zip((colors),ax["confusion_matrix"].yaxis.get_ticklabels())]
    ax['confusion_matrix'].set_title('RF (noramlized, percent)')
    ax['confusion_matrix'].set_ylabel('True Classes')
    ax['confusion_matrix'].yaxis.set_label_position("right")
    ax['confusion_matrix'].set_xlabel('Predicted Classes')
    
    # Our default 
    plt.subplots_adjust(
    top=0.969,
    bottom=0.0,
    left=0.048,
    right=0.95,
    hspace=0.0,
    wspace=0.25) 

    plt.show()


# Side-by-side normalised confusion matrices for the three models.
plot_conf_mat(
    true=onehot,
    mil_pred=mil_predictions,
    nn_pred=nn_predictions,
    rf_pred=rf_predictions,
    classNames=classes,
    notOneHot=correct,
)


Using matplotlib backend: Qt5Agg


  plt.subplots_adjust(


In [9]:
def plot_table(true, mil_pred, nn_pred, rf_pred, classNames, notOneHot):
    %matplotlib
    
    pred_data = [mil_pred, nn_pred, rf_pred]
    model_names = ["MIL", "NN", "RF"]
    
    accuracies = []
    precisions = []
    recalls = []
    AUCs = []
    
    for i in range(len(pred_data)):
        # Accuracies
        conf_mat = confusion_matrix(true.argmax(axis=1), pred_data[i].argmax(axis=1))
        acc_per_class = conf_mat.diagonal()/conf_mat.sum(axis=0)
        acc_per_class = ["%.2f" % value for value in acc_per_class]
        accuracies.append(acc_per_class)
        # Precision and Recall
        clf = classification_report(notOneHot, pred_data[i].argmax(axis=1),target_names=classNames,output_dict=True)
        p = []
        r = []
        a = [] #
        for j in clf:
            if (j!="accuracy" and j!="weighted avg" and j!="macro avg"):
                p_temp = clf[j]["precision"]
                r_temp = clf[j]["recall"]
                p_temp = "%.2f" % p_temp
                r_temp = "%.2f" % r_temp
                p.append(p_temp)
                r.append(r_temp)
        precisions.append(p)
        recalls.append(r)
        # AUC
        temp = []
        for j in range(onehot.shape[1]):
            temp.append(roc_auc_score(onehot[:, j], pred_data[i][:, j]))
        temp = ["%.2f" % value for value in temp]
        AUCs.append(temp)
    
    # Construct Table: Pandas Dataframe
    df = pd.DataFrame(list(zip(accuracies[2], accuracies[1], accuracies[0], 
                               precisions[2], precisions[1], precisions[0],
                               recalls[2], recalls[1], recalls[0],
                               AUCs[2], AUCs[1], AUCs[0])))
    columns = [('Accuracy','RF'), ('Accuracy','NN'), ('Accuracy','MIL'),
               ('Precision','RF'), ('Precision','NN'), ('Precision','MIL'),
               ('Recall','RF'), ('Recall','NN'), ('Recall','MIL'),
               ('AUC','RF'), ('AUC','NN'), ('AUC','MIL')]
    df.columns = pd.MultiIndex.from_tuples(columns)
    df.insert(loc=0, column='Cancer Type', value=classNames)
    
    # Visualize:
    display(df)
    '''fig,ax = plt.subplots()
    table = ax.table(cellText=df.values, colLabels=df.columns, loc='center')
    plt.show()'''
    
# Per-class metric table for the three models.
plot_table(
    true=onehot,
    mil_pred=mil_predictions,
    nn_pred=nn_predictions,
    rf_pred=rf_predictions,
    classNames=classes,
    notOneHot=correct,
)

Using matplotlib backend: Qt5Agg


Unnamed: 0_level_0,Cancer Type,Accuracy,Accuracy,Accuracy,Precision,Precision,Precision,Recall,Recall,Recall,AUC,AUC,AUC
Unnamed: 0_level_1,Unnamed: 1_level_1,RF,NN,MIL,RF,NN,MIL,RF,NN,MIL,RF,NN,MIL
0,"Adenocarcinoma, Pancreas",0.18,0.15,0.23,0.18,0.15,0.23,0.13,0.29,0.4,0.86,0.85,0.89
1,Astrocytoma,0.19,0.21,0.28,0.19,0.21,0.28,0.08,0.2,0.1,0.89,0.89,0.94
2,Cervical Squamous Cell Carcinoma,0.39,0.36,0.4,0.39,0.36,0.4,0.53,0.65,0.66,0.94,0.94,0.94
3,Clear Cell Renal Cell Carcinoma,0.38,0.46,0.5,0.38,0.46,0.5,0.57,0.41,0.55,0.94,0.95,0.96
4,Colorectal Adenocarcinoma,0.64,0.6,0.71,0.64,0.6,0.71,0.74,0.66,0.72,0.98,0.97,0.98
5,Cutaneous Melanoma,0.96,0.95,0.96,0.96,0.95,0.96,0.85,0.86,0.87,0.97,0.97,0.98
6,Endometrial Endometrioid Adenocarcinoma,0.65,0.6,0.67,0.65,0.6,0.67,0.5,0.56,0.54,0.93,0.94,0.95
7,Endometrial Serous Adenocarcinoma,0.29,0.13,0.14,0.29,0.13,0.14,0.05,0.29,0.45,0.85,0.86,0.89
8,Gastric Adenocarcinoma,0.7,0.56,0.66,0.7,0.56,0.66,0.47,0.57,0.57,0.92,0.93,0.94
9,Glioblastoma,0.37,0.48,0.55,0.37,0.48,0.55,0.6,0.42,0.52,0.92,0.93,0.94


In [4]:
# Uses TCGA labels

# Inspiration: https://www.nature.com/articles/s41467-019-13825-8/figures/2
def nature_plot(true, pred, classNames, notOneHot):
    %matplotlib
    
    cluster_rows = True

    fig, axs = plt.subplots(ncols=2, nrows=2, gridspec_kw= {'width_ratios':[8, .5], 'height_ratios':[0.5,8]})#dict(width_ratios=[1,4,0.2]))
    
    # Get class sizes
    class_sizes = [0]*len(classNames)
    for c in notOneHot:
        class_sizes[c] += 1
    temp = {classNames[i]: class_sizes[i] for i in range(len(classNames))}

    # Confusion Matrix 
    font = {'family' : 'normal',
            'size'   : 7}
    plt.rc('font', **font)
    ax_cm = axs[1,0]
    conf_mat_raw = confusion_matrix(true.argmax(axis=1), pred.argmax(axis=1), normalize='true')
    # reorder rows based on row clustering
    df = pd.DataFrame(conf_mat_raw)
    clustermap = sns.clustermap(df, col_cluster=False)
    reordered_rows = clustermap.dendrogram_row.reordered_ind
    reordered_row_clustered_labels = []
    for i in range(len(classNames)):
        reordered_row_clustered_labels.append(classNames[reordered_rows[i]])
    if cluster_rows == True:
        classNames = reordered_row_clustered_labels 
        conf_mat = confusion_matrix(true.argmax(axis=1), pred.argmax(axis=1), normalize='true', labels=reordered_rows)
        conf_mat = np.asarray([[round(j*100) for j in i] for i in conf_mat]) 
    else:
        #not clustered
        conf_mat = np.asarray([[round(j*100) for j in i] for i in conf_mat_raw]) 
    
    # Construct list with class names and sizes.
    classes_and_sizes = [0]*len(classNames)
    for i in range(len(classNames)):
        classes_and_sizes[i] = classNames[i] + " (" + str(temp[classNames[i]]) + ")"
    
    # Precision and Recall
    clf = classification_report(notOneHot, pred.argmax(axis=1),target_names=classNames,output_dict=True)
    precision = []
    recall = []
    for j in clf:
        if (j!="accuracy" and j!="weighted avg" and j!="macro avg"):
            precision.append(round(float("%.2f" % clf[j]["precision"])*100))
            recall.append(round(float("%.2f" % clf[j]["recall"])*100))
    p_temp = []
    r_temp = []
    if cluster_rows == True:
        for i in range(len(precision)): 
            p_temp.append(precision[reordered_rows[i]])
            r_temp.append(recall[reordered_rows[i]])
        precision = p_temp
        recall = r_temp           
    
    # To dataframe
    confusion_df = pd.DataFrame(conf_mat)
    precision_df = pd.DataFrame(precision) 
    recall_df = pd.DataFrame(recall).T 
    
    # Plotting
    sns.heatmap(confusion_df, annot=True, cbar=False, ax=axs[1,0], cmap=plt.cm.Blues)
    axs[1,0].set(yticks=np.arange(len(classNames)), yticklabels=classes_and_sizes, xticks=np.arange(len(classNames)), xticklabels=classNames)
    #axs[1,0].set_xticklabels(classNames, rotation=90)
    #axs[1,0].set_yticklabels(classes_and_sizes, rotation=0)
    axs[1,0].title.set_text('Confusion Matrix')
    sns.heatmap(precision_df, annot=True, yticklabels=False, cbar=False, ax=axs[1,1], cmap=plt.cm.Blues)
    axs[1,1].title.set_text('Precision')
    axs[1,1].set_xticks([])
    sns.heatmap(recall_df, annot=True, yticklabels=False, cbar=False, ax=axs[0,0], cmap=plt.cm.Blues)
    axs[0,0].title.set_text('Recall')
    axs[0,0].set_xticks([])
    
    # center tick marks
    div = [item + 0.5 for item in range(0, len(classes_and_sizes))]
    axs[1,0].set_yticklabels('') # Hide major tick labels
    axs[1,0].set_yticks(div,      minor=True) # Customize minor tick labels
    axs[1,0].set_yticklabels(classes_and_sizes, minor=True) 
    axs[1,0].set_xticklabels('', rotation=90) # Hide major tick labels
    axs[1,0].set_xticks(div,      minor=True) # Customize minor tick labels
    axs[1,0].set_xticklabels(classNames, minor=True, rotation=90)
    
    # extra formatting
    fig.delaxes(axs[0,1])
    
    fig.suptitle('MIL (with modified NCI-T Lables and context data)', fontsize=16)
    plt.show()
    fig.tight_layout()
    
nature_plot(onehot, nn_predictions,classes,correct)

Using matplotlib backend: Qt5Agg


findfont: Font family ['normal'] not found. Falling back to DejaVu Sans.
