In [1]:
import collections
from scipy import stats
from scipy.stats import ranksums
import pandas as pd
import numpy as np
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
import flowkit as fk
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from random import sample
import random
import time
# import xlrd
# from combat.pycombat import pycombat
import fcswrite
import re

from datetime import datetime
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_curve,auc

from sklearn.datasets import make_blobs
import colorcet as cc

In [2]:
def save_downsample_fcs(df_copy,mod_cell_label,in_path):
    """Write events plus their cluster assignment to an FCS file.

    Parameters
    ----------
    df_copy : pandas.DataFrame
        Event-by-channel table; its column names become the FCS channel names.
        It is not modified: the ``cluster`` column is added to a new frame.
    mod_cell_label : sequence
        One cluster label per row of ``df_copy``.
    in_path : str
        Destination ``.fcs`` path.
    """
    # ``assign`` returns a new frame, so the caller's DataFrame is left untouched.
    df_out = df_copy.assign(cluster=mod_cell_label)
    fcswrite.write_fcs(filename=in_path,chn_names=list(df_out.columns),data=df_out)
    print("downsample file saved")

def filter_infection(in_dfv_0,in_cell_cluster,in_dfv_0_arr,in_select_idx,in_get_column,in_fun_14_idx):
    """Restrict per-cell data to the infection group a comparison needs.

    Comparisons 3 and 4 keep cells with ``*Infection == 1``. Comparisons 5 and 6
    keep ``*Infection == 2``. Comparison 8 keeps every cell with
    ``*Infection != 2``.

    Parameters
    ----------
    in_dfv_0 : pandas.DataFrame
        Per-cell table containing ``*Infection`` and ``*PTID``.
    in_cell_cluster : numpy.ndarray
        Cluster label of each row of ``in_dfv_0``.
    in_dfv_0_arr : pandas.DataFrame
        Transformed expression, row-aligned by position with ``in_dfv_0``.
    in_select_idx : int
        Comparison index (3, 4, 5, 6 or 8).
    in_get_column : pandas.Index
        Column names of the raw table.
    in_fun_14_idx : list of int
        Positions in ``in_get_column`` of the functional channels to keep.

    Returns
    -------
    tuple
        ``(filtered table, filtered cluster labels, filtered functional
        expression, unique participant IDs)``. Both DataFrames are re-indexed
        from 0.

    Raises
    ------
    ValueError
        If ``in_select_idx`` is not a supported comparison.
    """
    infection = in_dfv_0['*Infection']
    if in_select_idx in (3, 4):
        keep_mask = infection == 1
    elif in_select_idx in (5, 6):
        keep_mask = infection == 2
    elif in_select_idx == 8:
        keep_mask = infection != 2
    else:
        raise ValueError("filter_infection: unsupported comparison index %r" % (in_select_idx,))

    # Row positions (not index labels), so .iloc is correct for any index and
    # an all-False mask yields an empty *integer* array usable for indexing.
    visit_list = np.flatnonzero(keep_mask.to_numpy())

    in_dfv_0 = in_dfv_0.iloc[visit_list, :].reset_index(drop=True)
    in_cell_cluster = in_cell_cluster[visit_list]
    in_p_PTID = np.unique(in_dfv_0['*PTID'])
    in_dfv_0_arr = in_dfv_0_arr.iloc[visit_list, :].reset_index(drop=True)
    in_dfv_0_arr = in_dfv_0_arr[in_get_column[in_fun_14_idx]]
    return in_dfv_0,in_cell_cluster,in_dfv_0_arr,in_p_PTID


def plot_boxplot(in_sst,in_df_total,in_cluster_label,in_cluster_num,in_PTID,in_T_name,in_F_name,compare_index,in_boxplot_path,y_title):
    """Draw one box/strip panel per cluster and test T vs F with a rank-sum test.

    Parameters
    ----------
    in_sst : list of int
        Column positions in ``in_df_total``. The k-th position is drawn in the
        k-th panel (callers pass one column per cluster).
    in_df_total : pandas.DataFrame
        Feature matrix with one row per participant.
    in_cluster_label : sequence of int
        Group of each row: 1 -> T group, 0 -> F group. Other values are skipped.
    in_cluster_num : int
        Number of panels to draw.
    in_PTID : sequence
        Participant IDs. Only its length (number of rows to read) is used.
    in_T_name, in_F_name : list of str
        Group display names, indexed by ``compare_index``.
    compare_index : int
        Comparison index. For ``compare_index <= 2`` the direction is 'up'
        when the T mean exceeds the F mean. For larger indices the labels are
        swapped ('low' when the T mean is higher).
    in_boxplot_path : str
        Output image path.
    y_title : str
        Y-axis label (also the value column name of the plotted frame).

    Returns
    -------
    tuple of (list of float, list of str)
        Wilcoxon rank-sum p-value and 'up'/'low' direction for each cluster.
    """
    one_marker=in_df_total.iloc[:,np.array(in_sst)]
    one_marker_c=one_marker.columns
    one_marker_arr=np.array(one_marker)
    n_rows=len(in_PTID)
    group_labels=np.array(in_cluster_label)[:n_rows]

    # Smallest square grid that fits every cluster (5x5 for 25). A fixed 5x5
    # grid made plt.subplot raise for more than 25 clusters.
    n_grid=max(1,int(np.ceil(np.sqrt(in_cluster_num))))

    get_pvalue=[]
    get_compare_uplow=[]

    plt.figure(figsize=(32, 30))
    sns.set(font_scale=1.2)
    sns.set_style(style='white')

    for cluster_i in range(in_cluster_num):
        values=one_marker_arr[:n_rows,cluster_i]
        data_T=list(values[group_labels==1])
        data_F=list(values[group_labels==0])
        lst_T=[in_T_name[compare_index]]*len(data_T)
        lst_F=[in_F_name[compare_index]]*len(data_F)

        t_is_higher=np.mean(data_T)>np.mean(data_F)
        if compare_index<=2:
            get_compare_uplow.append('up' if t_is_higher else 'low')
        else:
            get_compare_uplow.append('low' if t_is_higher else 'up')
        get_pvalue.append(ranksums(data_T, data_F).pvalue)

        data_df=pd.DataFrame({'condition': np.array(lst_T+lst_F), y_title: np.array(data_T+data_F)})
        plt.subplot(n_grid, n_grid, cluster_i+1)
        sns.boxplot(x="condition",y=y_title,palette="Set3",data=data_df).set(title=one_marker_c[cluster_i])
        sns.stripplot(x="condition",y=y_title,data=data_df)
    plt.savefig(in_boxplot_path, dpi=300)
    plt.show()
    return get_pvalue,get_compare_uplow

def tsne_initial(matrix_length,in_datetime_object,in_PTID,inall_cell_id,incell_label,expression_matrix,n_per_sample=1000):
    """Subsample cells per participant and compute a 2-D t-SNE embedding.

    Parameters
    ----------
    matrix_length : int
        Number of columns of ``expression_matrix``.
    in_datetime_object : datetime.datetime
        Its ``timestamp()`` reseeds Python's ``random`` so the subsample repeats.
    in_PTID : iterable
        Participant IDs to sample from, in order.
    inall_cell_id : sequence
        Participant ID of every row of ``expression_matrix``.
    incell_label : numpy.ndarray
        Cluster label of every row.
    expression_matrix : numpy.ndarray
        Cell-by-marker matrix to embed.
    n_per_sample : int, default 1000
        Maximum cells drawn per participant. Participants with fewer cells
        contribute all of them instead of making ``random.sample`` raise.

    Returns
    -------
    tuple
        ``(embedding with one (TSNE1, TSNE2) row per sampled cell,
        cluster label per sampled cell, row index of each sampled cell)``.
    """
    random.seed(in_datetime_object.timestamp())
    cell_ids=np.array(inall_cell_id)
    tsne_cell_label=[]
    index_list=list()
    # Seed with an empty float block so the result is float and has the right
    # width even if no participant contributes cells.
    blocks=[np.empty((0,matrix_length), float)]
    for single_sample in in_PTID:
        ii = np.where(cell_ids == single_sample)[0]
        sampled_list = random.sample(list(ii), min(n_per_sample, len(ii)))
        tsne_cell_label.extend(incell_label[sampled_list])
        index_list.extend(sampled_list)
        blocks.append(expression_matrix[sampled_list])
    # One concatenation instead of np.append per participant (quadratic copying).
    sub_trans=np.concatenate(blocks, axis=0)
    tsne = TSNE(n_components=2, perplexity=80,random_state=1024)
    in_all_embedded = tsne.fit_transform(sub_trans)
    return in_all_embedded,tsne_cell_label,index_list

def tsne_all_cluster(set_colors,tsne_df,in_tsne_path):
    """Scatter the t-SNE embedding coloured by cluster and save the figure.

    Parameters
    ----------
    set_colors : list of str
        Colours used as the hue palette for ``tsne_df['target']``.
    tsne_df : pandas.DataFrame
        Must contain ``TSNE1``, ``TSNE2`` and ``target`` columns.
    in_tsne_path : str
        Output image path.
    """
    cluster_palette = sns.color_palette(set_colors)
    # Swatch preview of the palette (drawn as its own figure).
    sns.palplot(cluster_palette)

    plt.figure(figsize=(26, 26))
    sns.set(font_scale=4)
    sns.set_style(style='white')

    ax = sns.scatterplot(
        data=tsne_df,
        x='TSNE1',
        y='TSNE2',
        hue='target',
        palette=cluster_palette,
        legend=False,
        s=90,
    )
    # Axis coordinates of a t-SNE embedding carry no meaning, so hide the ticks.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set(xlabel='TSNE1', ylabel='TSNE2')
    plt.savefig(in_tsne_path, dpi=300)
    plt.plot()

def tsne_each_cluster(tsne_df,in_all_embedded,in_tsne_path,n_clusters=25):
    """Draw one t-SNE panel per cluster with that cluster highlighted.

    Parameters
    ----------
    tsne_df : pandas.DataFrame
        Must contain a ``target`` column whose values are compared against the
        strings '1', '2', ...; row-aligned with ``in_all_embedded``.
    in_all_embedded : numpy.ndarray
        Embedding with one (TSNE1, TSNE2) row per cell.
    in_tsne_path : str
        Output image path.
    n_clusters : int, default 25
        Number of clusters to draw. The grid is the smallest square that fits
        them (5x5 for the default). Unused grid cells are hidden.
    """
    sns.set(font_scale=1.5)
    sns.set_style(style='white')
    n_grid=max(1,int(np.ceil(np.sqrt(n_clusters))))
    # 8 x 7 inches per panel, i.e. a 40 x 35 figure for a 5x5 grid.
    fig, axes = plt.subplots(n_grid,n_grid, figsize=(8*n_grid, 7*n_grid), squeeze=False)
    # A dict literal is used so the call cannot break if the global name `dict`
    # has been rebound in the kernel.
    pplat={'others': "#95a5a6", 'select': "#3498db"}
    targets=np.asarray(tsne_df['target'])
    for counts, ax in enumerate(axes.flat, start=1):
        if counts>n_clusters:
            ax.set_axis_off()
            continue
        pca_temp = pd.DataFrame(data=in_all_embedded, columns=['TSNE1', 'TSNE2'])
        pca_temp['target']=np.where(targets==str(counts), 'select', 'others')
        sns.scatterplot(ax=ax,x='TSNE1',y='TSNE2',data=pca_temp,hue='target',palette=pplat,legend=True,s=3)
        ax.set_title("cluster "+str(counts))
    plt.savefig(in_tsne_path, dpi=300)
    plt.plot()
    print("finish plot2")

def tsne_freq_phenotype(in_Frame,tsne_df,in_all_embedded,in_tsne_path):
    """Colour the t-SNE embedding by each phenotypic cluster-centre value.

    One column of panels is drawn per column of ``in_Frame``. Each cell is
    coloured by the value its cluster's centre has for that marker, with a
    horizontal colour bar underneath.

    Parameters
    ----------
    in_Frame : pandas.DataFrame
        Cluster centres. Row k belongs to the cells whose ``target`` is
        ``str(k+1)``. Panel titles use the text after the last ``_`` of each
        column name.
    tsne_df : pandas.DataFrame
        Must contain a ``target`` column of 1-based cluster labels as strings,
        row-aligned with ``in_all_embedded``.
    in_all_embedded : numpy.ndarray
        Embedding with one (TSNE1, TSNE2) row per cell.
    in_tsne_path : str
        Output image path.
    """
    sns.set(font_scale=4)
    sns.set_style(style='white')
    n_markers=in_Frame.shape[1]
    # squeeze=False keeps axes 2-D even for a single marker column.
    fig, axes = plt.subplots(2, n_markers, figsize=(12*n_markers, 11),
                             gridspec_kw={'height_ratios': [16,1]}, squeeze=False)
    bold={'weight': 'bold'}
    for i, marker_name in enumerate(in_Frame.columns):
        # 1-based cluster label -> that cluster's centre value for this marker.
        centre_by_label={str(k+1): value for k, value in enumerate(in_Frame.iloc[:,i])}
        pca_temp = pd.DataFrame(data=in_all_embedded, columns=['TSNE1', 'TSNE2'])
        pca_temp['target']=np.asarray(tsne_df['target'].map(centre_by_label), dtype=float)
        ax=sns.scatterplot(ax=axes[0,i],x='TSNE1',y='TSNE2',data=pca_temp,hue='target',palette="viridis_r",legend=False,s=3)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('TSNE1', fontdict=bold)
        ax.set_ylabel('TSNE2', fontdict=bold)
        norm = plt.Normalize(pca_temp['target'].min(), pca_temp['target'].max())
        sm = plt.cm.ScalarMappable(cmap="viridis_r", norm=norm)
        sm.set_array([])
        cbar = ax.figure.colorbar(sm,orientation='horizontal',cax=axes[1,i])
        cbar.ax.tick_params(labelsize=50)
        for t in cbar.ax.xaxis.get_major_ticks():
            t.label1.set_fontweight('bold')
        axes[0,i].set_title(str(marker_name).split("_")[-1], fontdict=bold)
    plt.savefig(in_tsne_path, dpi=300)
    plt.plot()

def tsne_func_phenotype(in_Frame_14,in_fun_14,in_all_embedded,index_list,in_tsne_path):
    """Colour the t-SNE embedding by each functional marker's per-cell value.

    Panels are laid out five per row, with as many rows as needed. Unused grid
    cells are hidden. The old version stopped on a hard-coded count of 13
    markers (and only broke out of the inner loop), so other marker counts
    raised ``IndexError``.

    Parameters
    ----------
    in_Frame_14 : pandas.DataFrame
        Per-cell expression, one column per name in ``in_fun_14``.
    in_fun_14 : list of str
        Markers to plot, in panel order.
    in_all_embedded : numpy.ndarray
        Embedding with one (TSNE1, TSNE2) row per sampled cell.
    index_list : sequence of int
        Row of ``in_Frame_14`` for each embedded cell.
    in_tsne_path : str
        Output image path.
    """
    sns.set(font_scale=1.5)
    sns.set_style(style='white')
    n_cols=5
    n_rows=max(1,int(np.ceil(len(in_fun_14)/n_cols)))
    # 12 x 10 inches per panel, i.e. a 60 x 30 figure for a 3x5 grid.
    fig, axes = plt.subplots(n_rows,n_cols, figsize=(12*n_cols, 10*n_rows), squeeze=False)
    sampled_rows=np.array(index_list)
    flat_axes=list(axes.flat)
    for ax, marker in zip(flat_axes, in_fun_14):
        pca_temp = pd.DataFrame(data=in_all_embedded, columns=['TSNE1', 'TSNE2'])
        pca_temp['target']=np.array(in_Frame_14[marker])[sampled_rows]
        sns.scatterplot(ax=ax,x='TSNE1',y='TSNE2',data=pca_temp,hue='target',palette="viridis",legend=True,s=5)
        ax.set_title(marker)
    for ax in flat_axes[len(in_fun_14):]:
        ax.set_axis_off()
    plt.savefig(in_tsne_path, dpi=300)
    plt.plot()
    print("finish plot4")

In [4]:
# TSNE color define: 25 colours, passed as the cluster palette to
# tsne_all_cluster (that call is currently commented out).
colors = [
    "#FF0000", "#FFA500", "#FFFF00", "#008000", "#00FFFF",  # red, orange, yellow, green, cyan
    "#0000FF", "#FF00FF", "#800080", "#FFC0CB", "#FF4500",  # blue, magenta, purple, pink, orange red
    "#FFD700", "#808000", "#008080", "#000080", "#FF1493",  # gold, olive, teal, navy, deep pink
    "#FF69B4", "#BA55D3", "#8A2BE2", "#1E90FF", "#808080",  # hot pink, medium orchid, blue violet, dodger blue, gray
    "#4B0082", "#7FFF00", "#FFA07A", "#800000", "#20B2AA",  # indigo, chartreuse, light salmon, maroon, light sea green
]

In [5]:
# Accumulators filled by the analysis cell: out-of-fold RF predictions and true
# labels (one list per comparison run) and heatmap frames keyed by comparison name.
all_pred=[] ###
all_true=[] ###
heatmap_pkl={}

file_name='post_combat_gated_CD4' #
cluster_num=25##change cluster number
# Entry k of every per-comparison list below describes comparison k (select_idx).
compare_name=['Uninf','Uninf','CT only','Endo+','FU+','Endo+','FU+','Endo+','FU+']
# Comparisons whose cells are first restricted to one infection group.
infection_select_idx=[3,4,5,6,8]

# Lower bound each functional channel is raised to before per-cluster quantiles.
min_thr_dict={"Dy161Di":0.03998934,"Tm169Di":0.09983408,"Eu153Di":0.16524692,"Yb172Di":0.53422407,"Nd150Di":0.19869011,"Sm154Di":0.19869011,"Gd155Di":0.04997919,"Ho165Di":0.35264533,"Gd158Di":0.39003532,"Gd160Di":0.39003532,"Yb171Di":0.07991491,"Yb174Di":0.09983408,"Eu151Di":0.05996406}
name_to_visit='v0_noH' ###
condition_lst=['*Infection','*Infection','*Infection','*Ascension','*FollowUp','*Ascension','*FollowUp','*Ascension','*FollowUp']
name_to_select=['inf01','inf02','inf12','inf1-asc01','inf1-fol01','inf2-asc01','inf2-fol01','asc01','fol01']
label_lst=[[0,1],[0,2],[1,2],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1]]

fun_14=['161Dy_CD152_CTLA4', '151Eu_CD107a_LAMP1_and_Bead3' , '153Eu_CD185_CXCR5_and_Bead4', '155Gd_CD279_PD_1', '158Gd_CD27', '160Gd_CD28', '165Ho_CD127_and_Bead5', '150Nd_HLA_DR_DP','154Sm_CD71', '169Tm_CD25', '171Yb_CD195_CCR5', '172Yb_CD38', '174Yb_CD94'] # ,'164Dy_CD95'
phy_5=['143Nd_CD45RA', '167Er_CD197_CCR7','164Dy_CD95', '141Pr_CD196_CCR6', '156Gd_CD183_CXCR3', '149Sm_CD194_CCR4']

name_of_T=['CT only','Coinfected','Coinfected','inf1-Endo+','inf1-Followup Positive','inf2-Endo+','inf2-Followup Positive','Endo+','Followup Positive'] #'HD','HD','HD',
name_of_F=['Uninfected','Uninfected','CT only','inf1-Endo-','inf1-Followup Negative','inf2-Endo-','inf2-Followup Negative','Endo-','Followup Negative'] # 'Uninfected','CT only','Coinfected', ,'Non-ascending','Uninfected','Uninfected'

label_format={"174Yb_CD94":"CD94","161Dy_CD152_CTLA4":"CTLA4","169Tm_CD25":"CD25","171Yb_CD195_CCR5":"CCR5","165Ho_CD127_and_Bead5":"CD127","phenotypic":"Frequency","150Nd_HLA_DR_DP":"HLA-DR","158Gd_CD27":"CD27","155Gd_CD279_PD_1":"CD279(PD-1)","172Yb_CD38":"CD38","160Gd_CD28":"CD28","151Eu_CD107a_LAMP1_and_Bead3":"CD107a(LAMP1)","153Eu_CD185_CXCR5_and_Bead4":"CXCR5","154Sm_CD71":"CD71"}


# Phenotype name for each KMeans cluster (cluster k+1 -> entry k).
if file_name=='post_combat_gated_CD4':
    pheon_label_name=['TN','TCM','TEM Th2','TEM Th2','TCM','TCM Th2','TEM','TN','TEM Th2','TN','TCM','TCM Th17','TEM Th1','Transitional TCM','TSCM','Transitional TCM Th1','TCM Th2','TCM Th1','TN','TN','TEM Th17 DN','TN','TCM','TEMRA','TEM']
else:
    pheon_label_name=['TN (CXCR3+)','TEM','TCM','TN (CXCR3+)','TEM Tc17 DN','TN','TCM (CXCR3+CCR4+)','TCM Tc1','TEMRA','TSCM','TCM Tc1','TCM Tc2','TN (CXCR3+)','TEMRA','TN','TSCM (CXCR3+)','TSCM','TEMRA (CXCR3+)','TN (CXCR3+)','TCM','TEMRA','TN','TN','TEM','TCM Tc1']

# Fail fast when the parallel configuration lists drift apart, instead of an
# IndexError/KeyError deep inside the analysis loop.
assert all(len(lst)==len(name_to_select) for lst in (compare_name,condition_lst,label_lst,name_of_T,name_of_F)), \
    "per-comparison lists must each have len(name_to_select) entries"
assert all(0<=k<len(name_to_select) for k in infection_select_idx), \
    "infection_select_idx refers to an unknown comparison"
assert all(m in label_format for m in fun_14), \
    "every functional marker needs a label_format entry"
# '161Dy_CD152_CTLA4' -> channel 'Dy161Di', the key used by min_thr_dict.
assert all(m.split('_')[0][3:5]+m.split('_')[0][0:3]+'Di' in min_thr_dict for m in fun_14), \
    "every functional marker needs a min_thr_dict threshold"
assert len(pheon_label_name)==cluster_num, \
    "pheon_label_name must have one entry per cluster"


In [None]:

#random seeds define: derived from a fixed timestamp so every run draws the same samples.
# NOTE(review): .timestamp() of a naive datetime is interpreted in the machine's
# local time zone, so the seed value depends on the time zone of the machine.
timestamp_string = '2023-02-03T14::12::53'
timestamp_format = '%Y-%m-%dT%H::%M::%S'   # not named `format`, which would shadow the builtin
datetime_object = datetime.strptime(timestamp_string, timestamp_format)
random.seed(datetime_object.timestamp())
# Seed NumPy's global RNG too: scikit-learn falls back to it when random_state is
# None (e.g. StratifiedKFold(shuffle=True) without a random_state).
np.random.seed(int(datetime_object.timestamp()))


#load file
# Directory holding the per-cell CSV exports. Set the VIS_DATA_DIR environment
# variable to run elsewhere instead of editing the path.
import os
DATA_DIR = os.environ.get('VIS_DATA_DIR', '/Users/cookie/Documents/UNC/research/vis/')
# Maximum number of cells kept per participant.
CELLS_PER_PARTICIPANT = 15000

save_name=file_name.split('_')[3]   # 'post_combat_gated_CD4' -> 'CD4'
df = pd.read_csv(os.path.join(DATA_DIR, file_name+'_modify_noHD.csv'))

all_cell_id=df['*PTID']
unique_ID=np.unique(df['*PTID'])
select_cell_per_ID=list()

#downsample per participant
cell_id_arr=np.array(all_cell_id)   # converted once, not once per participant
for u_id in unique_ID:
    id_idx=np.where(cell_id_arr == u_id)[0]
    # random.sample raises ValueError when asked for more items than exist, so a
    # participant with fewer cells keeps all of them.
    downsample_id=sample(list(id_idx), min(CELLS_PER_PARTICIPANT, len(id_idx)))
    select_cell_per_ID.extend(downsample_id)

df=df.iloc[select_cell_per_ID].reset_index(drop=True)
all_cell_id=df['*PTID']
unique_ID=np.unique(df['*PTID'])

# Column positions 0-30 are treated as the marker channels (CSV layout assumption).
pick_column=list(np.arange(0,31,1))
get_column=df.columns

#pick frequency features: column positions of the phenotypic markers used for KMeans
phenotype=np.array([16,4,3,21,10,23])

# Locate each functional marker's channel column: '161Dy_CD152_CTLA4' -> 'Dy161Di'.
column_names=list(get_column)
fun_14_idx=[]
for marker in fun_14:
    mass_element=marker.split('_')[0]
    fun_14_idx.append(column_names.index(mass_element[3:5]+mass_element[0:3]+'Di'))

# Feature names in the order of the per-participant feature vector: all cluster
# frequencies first, then every functional marker for each cluster.
features_name=["cluster "+str(i+1)+": phenotypic markers" for i in range(cluster_num)]
features_name+=["cluster "+str(i+1)+": "+marker for i in range(cluster_num) for marker in fun_14]

#loop for different comparsion - the comparsion are list in name_to_select
# Indices into name_to_select evaluated in this run; use
# range(len(name_to_select)) to run every comparison.
selected_comparisons=range(0,1,1)


def _upper_quartile_expression(patient_expr, patient_clusters, n_clusters, thresholds):
    """Per-cluster 75th percentile of every functional marker for one participant.

    patient_expr is (cells x markers). Values are first raised to at least the
    marker's threshold. A cluster with at most one cell contributes zeros, and
    the zero placeholder is not clamped.
    NOTE(review): single-cell clusters are treated like empty ones; confirm intent.
    """
    n_markers=patient_expr.shape[1]
    features=[]
    for cluster_id in range(n_clusters):
        rows=np.where(patient_clusters==cluster_id)[0]
        if len(rows)>1:
            clamped=np.maximum(patient_expr[rows], thresholds)
            features.extend(np.quantile(clamped, 0.75, axis=0))
        else:
            features.extend([0.0]*n_markers)
    return features


def _n_leading_significant(sorted_p, alpha=0.05):
    """Count the leading entries of ascending p-values that are <= alpha.

    NaN counts as not significant. If every value is significant, all are kept.
    """
    not_significant=np.flatnonzero(~(np.asarray(sorted_p, dtype=float)<=alpha))
    return int(not_significant[0]) if len(not_significant) else len(sorted_p)


def _p_value_label(p):
    """Significance label for a p-value; 'None' for p >= 0.05 or NaN."""
    for upper, label in ((0.0001,'****P<0.0001'),(0.001,'***P<0.001'),(0.01,'**P<0.01'),(0.05,'*P<0.05')):
        if p<upper:
            return label
    return 'None'


# --- Preprocessing shared by every comparison (computed once) ---
df_arr=np.array(df)
# arcsinh transform (cofactor 5) of the marker channels
dm_0_trans = np.arcsinh(1./5 * df_arr[:,pick_column])
dm_to_df=pd.DataFrame(data=dm_0_trans,columns=get_column[pick_column])
# pick_column is intentionally not modified, so this cell can be re-run.

# Kmean by frequency features
dm_0_trans = np.array(dm_to_df[get_column[phenotype]])
kmeans = KMeans(n_clusters=cluster_num, random_state=0).fit(dm_0_trans)
cluster_centers_arr=np.array(kmeans.cluster_centers_)
Frame=pd.DataFrame(cluster_centers_arr, columns = phy_5)
Frame_14=pd.DataFrame(np.array(dm_to_df[get_column[fun_14_idx]]), columns = fun_14)
cell_label=kmeans.labels_

# save downsample fcs
# save_downsample_fcs(df.copy(),[i+1 for i in cell_label],"/Users/cookie/Documents/UNC/research/vis/"+file_name+"_downsample_cluster.fcs")

# # plot TSNE
# tane_file_path="/Users/cookie/Downloads/test/"+file_name+"TNES_data.csv"
# matrix_len=len(get_column[phenotype])
# # TSNE initialize
# all_embedded,sub_cell_label,idx_lst=tsne_initial(matrix_len,datetime_object,unique_ID,all_cell_id,cell_label,dm_0_trans)

# pca_df = pd.DataFrame(data=all_embedded, columns=['TSNE1', 'TSNE2'])
# pca_df['target']=[str(x+1) for x in sub_cell_label]
# pca_df['index']=idx_lst
# pca_df.to_csv("/Users/cookie/Downloads/test/"+file_name+"TNES_data.csv")

# # plot TNSE for all cluster (figure 5B)
# tsne_plot_path="/Users/cookie/Downloads/test/"+save_name+"_allcluster.png"
# tsne_all_cluster(colors,pca_df,tsne_plot_path)

# #plot TNSE for each cluster
# tsne_each_cluster(pca_df,all_embedded,tsne_plot_path)
# tsne_plot_path="/Users/cookie/Downloads/rerun_vis/"+save_name+"_cluster_phenotype.png"

# #plot TNSE for frequency features  (figure 5A)
# tsne_plot_path="/Users/cookie/Downloads/rerun_vis/"+save_name+"_markers_phenotype.png"
# tsne_freq_phenotype(Frame,pca_df,all_embedded,tsne_plot_path)

# #plot TNSE for functional features
# tsne_plot_path="/Users/cookie/Downloads/GC/fun_freq/"+save_name+"_14func_phenotype.png"
# tsne_func_phenotype(Frame_14,fun_14,all_embedded,idx_lst,tsne_plot_path)

functional_columns=get_column[fun_14_idx]
# Threshold per functional channel, in the order of functional_columns.
marker_thresholds=np.array([min_thr_dict[c] for c in functional_columns], dtype=float)
n_features=cluster_num+len(fun_14)*cluster_num

for select_idx in selected_comparisons:

    cell_cluster=kmeans.labels_

    # define enrollment or one month (row positions of visit-0 cells) ### |(df['*Visit'] ==2)
    visit_list=np.flatnonzero((df['*Visit']==0).to_numpy())
    check_sum=0
    dfv_0=df.iloc[visit_list,:].reset_index(drop=True)
    cell_cluster=cell_cluster[visit_list]
    p_PTID=np.unique(dfv_0['*PTID'])
    dfv_0_arr=dm_to_df.iloc[visit_list,:].reset_index(drop=True)[functional_columns]

    # filter data if the comparsion only look into specific condition
    if select_idx in infection_select_idx:
        dfv_0,cell_cluster,dfv_0_arr,p_PTID=filter_infection(dfv_0,cell_cluster,dfv_0_arr,select_idx,get_column,fun_14_idx)

    #create vector per participant: cluster frequencies, then per-cluster
    # upper-quartile functional expression
    condition_select=condition_lst[select_idx]
    infection_dfv_0=np.array(dfv_0[condition_select])
    label_for_condition=label_lst[select_idx]
    participant_ids=dfv_0['*PTID'].to_numpy()
    functional_arr=dfv_0_arr.to_numpy(dtype=float)
    cluster_label=[]
    get_PTID=[]
    feature_rows=[]
    for iid in p_PTID:
        ii = np.where(participant_ids == iid)[0]
        # the condition is taken per participant (smallest value present)
        get_val=np.unique(infection_dfv_0[ii])[0]
        if get_val not in label_for_condition:
            continue
        get_PTID.append(iid)
        # first listed value is the negative class, any other listed value positive
        cluster_label.append(1 if get_val in label_for_condition[1:] else 0)

        freq=np.bincount(cell_cluster[ii], minlength=cluster_num)
        check_sum+=int(freq.sum())
        function_matrix=_upper_quartile_expression(functional_arr[ii], cell_cluster[ii], cluster_num, marker_thresholds)
        freq_feature=freq/freq.sum()
        feature_rows.append(np.concatenate((freq_feature, np.array(function_matrix))))
    cluster_matrix=np.array(feature_rows, dtype=float).reshape(-1, n_features)

    sub_matrix=np.array(cluster_matrix)
    labels_arr=np.array(cluster_label)

    # train data: 30 repeats of shuffled, stratified 5-fold cross-validation
    fold_pred=[]
    fold_true=[]
    collect_all_auc=[]
    gini_rows=[]
    random_state = np.random.RandomState(0)
    # Dedicated, seeded generator for the CV shuffles, so splits and AUCs are
    # reproducible. The forest keeps its own random_state stream.
    cv_random_state = np.random.RandomState(1)
    for _repeat in range(30):
        clf = RandomForestClassifier(random_state=random_state)
        cv = StratifiedKFold(n_splits=5,shuffle=True,random_state=cv_random_state)
        single_time_prob=[]
        single_time_true=[]
        for train,test in cv.split(sub_matrix,labels_arr):
            clf.fit(sub_matrix[train],labels_arr[train])
            gini_rows.append(clf.feature_importances_)
            lr_probs = clf.predict_proba(sub_matrix[test])[:, 1]
            single_time_prob+=list(lr_probs)
            single_time_true+=list(labels_arr[test])
        fold_pred+=single_time_prob
        fold_true+=single_time_true
        fpr, tpr, _thresholds = roc_curve(single_time_true, single_time_prob)
        collect_all_auc.append(auc(fpr, tpr))

    all_pred.append(fold_pred)
    all_true.append(fold_true)
    run_tag=save_name+"_"+name_to_visit+"_"+name_to_select[select_idx]
    auc_csv_save = pd.DataFrame(data=collect_all_auc, columns=['AUC_result'])
    auc_csv_save.to_csv(run_tag+"_auc_record.csv")

    #features importance plot (mean gini importance over all folds)
    plt.figure(figsize=[40,120])
    gini_matrix=np.mean(np.array(gini_rows).reshape(-1, len(features_name)), axis=0)
    sorted_idx = gini_matrix.argsort()
    sorted_names=np.array(features_name)[sorted_idx]
    sorted_gini=gini_matrix[sorted_idx]
    plt.barh(sorted_names, sorted_gini)
    plt.xlabel("Random Forest Feature Importance")
    plt.title(file_name+" "+condition_select+" "+str(label_for_condition))
    plt.savefig(save_name+"_feature_"+name_to_visit+"_"+name_to_select[select_idx]+".png", dpi=300)
    plt.show()

    #csv file for features_rank (rank 1 = highest importance)
    features_rank=np.argsort(np.argsort(-sorted_gini))+1
    feature_rank_csv = pd.DataFrame(data=sorted_names, columns=['features title'])
    feature_rank_csv['gini_score']=sorted_gini
    feature_rank_csv['rank']=features_rank
    feature_rank_csv.to_csv(run_tag+"_features_rank.csv")

    #boxplot initial: every functional marker present among the features
    top_5_hit=sorted({name.split(": ")[1] for name in sorted_names}-{'phenotypic markers'})
    top_5_hit_idx=[fun_14.index(marker) for marker in top_5_hit]

    df_total=pd.DataFrame(data=cluster_matrix, columns=features_name)
    pvalue_csv=pd.DataFrame(data=[i+1 for i in range(cluster_num)], columns=['cluster'])
    upper_lower_csv=pd.DataFrame(data=[i+1 for i in range(cluster_num)], columns=['cluster'])

    #upper quartile expression boxplot for functional features
    for marker, marker_idx in zip(top_5_hit, top_5_hit_idx):
        # this marker's column for clusters 1..cluster_num
        sst=list(range(cluster_num+marker_idx,len(df_total.columns),len(fun_14)))
        get_p,get_compare_up_low=plot_boxplot(sst,df_total,cluster_label,cluster_num,get_PTID,name_of_T,name_of_F,select_idx,run_tag+"_"+marker+".png","upper quartile expression")
        upper_lower_csv[marker]=get_compare_up_low
        pvalue_csv[marker]=get_p

    #frequency boxplot for clusters
    sst=list(range(cluster_num))
    get_p,get_compare_up_low=plot_boxplot(sst,df_total,cluster_label,cluster_num,get_PTID,name_of_T,name_of_F,select_idx,run_tag+"_phenotype.png","frequency")
    pvalue_csv['phenotype']=get_p
    upper_lower_csv['phenotype']=get_compare_up_low

    #csv for allfeatures_rank
    gini_name_lst=np.array(feature_rank_csv["features title"])
    # "cluster 3: 158Gd_CD27" -> ['cluster', '3', '158Gd_CD27'];
    # "cluster 3: phenotypic markers" -> ['cluster', '3', 'phenotypic', 'markers']
    split_names=[re.split(": | ",name) for name in gini_name_lst]
    cluster_number_lst=[parts[1] for parts in split_names]
    marker_of_feature=[parts[2] for parts in split_names]
    unique_name=list(np.unique(marker_of_feature))
    features_idx_dict={}
    for row, marker in enumerate(marker_of_feature):
        features_idx_dict.setdefault(marker, []).append(row)

    arrow_compare_to_ctrl,total_p=[],[]
    for cluster_no, marker in zip(cluster_number_lst, marker_of_feature):
        stats_column='phenotype' if marker=='phenotypic' else marker
        arrow_compare_to_ctrl.append(upper_lower_csv[stats_column][int(cluster_no)-1])
        total_p.append(pvalue_csv[stats_column][int(cluster_no)-1])
    feature_rank_csv['p_value']=total_p
    feature_rank_csv['arrow_compare_to_ctrl']=arrow_compare_to_ctrl

    # Per marker: clusters sorted by p-value, keeping those with p <= 0.05
    Functional_Marker,Avg_Gini,Gini_score_list,Cluster,P_Value,Upper_or_Lower=[],[],[],[],[],[]
    p_values=np.array(feature_rank_csv['p_value'], dtype=float)
    gini_scores=np.array(feature_rank_csv['gini_score'])
    arrows=np.array(feature_rank_csv['arrow_compare_to_ctrl'])
    cluster_numbers=np.array(cluster_number_lst)
    for feature_nam in unique_name:
        rows=features_idx_dict[feature_nam]
        order=np.argsort(p_values[rows])
        keep_idx=_n_leading_significant(p_values[rows][order])
        Functional_Marker.append(feature_nam)
        Avg_Gini.append(np.mean(gini_scores[rows]))
        P_Value.append(p_values[rows][order][:keep_idx])
        Cluster.append(cluster_numbers[rows][order][:keep_idx])
        Gini_score_list.append(gini_scores[rows][order][:keep_idx])
        Upper_or_Lower.append(arrows[rows][order][:keep_idx])

    # markers by decreasing average gini score
    get_sort_list=list(reversed(np.argsort(Avg_Gini)))
    Functional_Marker=np.array(Functional_Marker)[get_sort_list]
    Avg_Gini=np.array(Avg_Gini)[get_sort_list]
    Gini_score_list=[Gini_score_list[i] for i in get_sort_list]
    Cluster=[Cluster[i] for i in get_sort_list]
    P_Value=[P_Value[i] for i in get_sort_list]
    Upper_or_Lower=[Upper_or_Lower[i] for i in get_sort_list]

    new_Functional_Marker,new_Avg_Gini,new_Gini_score_list,new_Cluster,new_P_Value,new_Upper_or_Lower=[],[],[],[],[],[]
    for i_fm, marker in enumerate(Functional_Marker):
        for j_pv in range(len(P_Value[i_fm])):
            new_Functional_Marker.append(label_format[marker])
            new_Avg_Gini.append(round(Avg_Gini[i_fm],4))
            new_Gini_score_list.append(round(Gini_score_list[i_fm][j_pv],4))
            new_Cluster.append(Cluster[i_fm][j_pv])
            new_P_Value.append(P_Value[i_fm][j_pv])
            new_Upper_or_Lower.append(Upper_or_Lower[i_fm][j_pv])

    pheon_cluster_name=[pheon_label_name[int(cluster_no)-1] for cluster_no in new_Cluster]
    # exactly one label per p-value, so the columns below always line up
    newp_label=[_p_value_label(p) for p in new_P_Value]

    final_csv=pd.DataFrame(data=new_Functional_Marker, columns=['Marker'])
    final_csv['Avg Gini Score']=new_Avg_Gini
    final_csv['Gini Score']=new_Gini_score_list
    final_csv['Cluster']=new_Cluster
    final_csv['Phenotype']=pheon_cluster_name
    final_csv['P_Value']=newp_label
    final_csv['Compared to '+compare_name[select_idx]]=new_Upper_or_Lower
    final_csv.to_csv(run_tag+"_allfeatures_rank.csv",index=False)

    #save heatmap data: marker (rows) x cluster (columns) gini score, negated
    # where the comparison direction is 'low'
    feature_rank_modify=feature_rank_csv.copy()
    feature_rank_modify["Marker"]=marker_of_feature
    feature_rank_modify["Cluster"]=[int(c) for c in cluster_number_lst]
    heatmap_rows=[]
    for marker in list(fun_14)+['phenotypic']:
        marker_rows=feature_rank_modify[feature_rank_modify['Marker']==marker].sort_values(by=['Cluster'])
        signs=np.where(marker_rows['arrow_compare_to_ctrl'].to_numpy()=='low', -1, 1)
        heatmap_rows.append(signs*marker_rows['gini_score'].to_numpy(dtype=float))
    kee_fea_name=[label_format[m] for m in fun_14]+['Frequency']
    kee_cluster_name=[str(i+1)+': '+str(name) for i, name in enumerate(pheon_label_name)]
    heatmap_data=pd.DataFrame(np.array(heatmap_rows, dtype=float),
                    index=kee_fea_name, columns=kee_cluster_name)
    heatmap_pkl[name_to_select[select_idx]]=heatmap_data



# Out-of-fold random-forest predictions and true labels per comparison, keyed
# like 'inf1_asc01_pred' / 'inf1_asc01_true'. all_pred[k] is paired with
# name_to_select[k], which assumes the comparison loop ran a prefix of
# name_to_select (e.g. range(0, k)).
# NOTE(review): adjust the pairing if another subset of comparisons is run.
# zip stops at the shorter list, so running fewer than all comparisons no longer
# raises IndexError.
rf_results={}
for comparison_name, comparison_pred, comparison_true in zip(name_to_select, all_pred, all_true):
    key_prefix=comparison_name.replace('-', '_')
    rf_results[key_prefix+'_pred']=comparison_pred
    rf_results[key_prefix+'_true']=comparison_true

import pickle
# `with` closes both files even if pickling raises.
with open(save_name+"_"+name_to_visit+"_heatmap_df.pkl","wb") as f_hm:
    pickle.dump(heatmap_pkl,f_hm)
with open(save_name+"_"+name_to_visit+"_dict.pkl","wb") as f_rf:
    pickle.dump(rf_results,f_rf)

In [None]:
#plot heatmap (figure 6A)
# Draws a 2x2 grid of heatmaps (rows: enrollment / 1 month, columns: CD4 / CD8)
# of the signed feature scores for one comparison, on a shared zero-centred
# colour scale, and saves it as a PNG.
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pickle

name_of_T=['CT only','Coinfected','Coinfected','CT only-Endo+','CT only-Followup Positive','Coinfected-Endo+','Coinfected-Followup Positive'] #'HD','HD','HD',
name_of_F=['Uninfected','Uninfected','CT only','CT only-Endo-','CT only-Followup Negative','Coinfected-Endo-','Coinfected-Followup Negative']

# Heatmap tables written by the RF export cell (trusted local files).
with open("CD4_v0_noH_heatmap_df.pkl", 'rb') as f:
    cd4_0 = pickle.load(f)
with open("CD4_v1_noH_heatmap_df.pkl", 'rb') as f:
    cd4_1 = pickle.load(f)
with open("CD8_v0_noH_heatmap_df.pkl", 'rb') as f:
    cd8_0 = pickle.load(f)
with open("CD8_v1_noH_heatmap_df.pkl", 'rb') as f:
    cd8_1 = pickle.load(f)

name_to_select=['inf01','inf02','inf12','inf1-asc01','inf1-fol01','inf2-asc01','inf2-fol01']
idx_select=6  # which comparison (index into name_to_select / name_of_T / name_of_F) to draw
comparison = name_to_select[idx_select]

# For the first three comparisons the F group is named first, otherwise the
# T group; the same ordering is used for the title and the output filename.
if idx_select<=2:
    pair_names = (name_of_F[idx_select], name_of_T[idx_select])
else:
    pair_names = (name_of_T[idx_select], name_of_F[idx_select])

fig, axn = plt.subplots(2, 2, sharey=True,figsize=(73,45), constrained_layout=True)
fig.suptitle(" v.s. ".join(pair_names), fontsize=80)

row_cluster=[13,12,9,0,7,4,6,10,5,3,11,8,1,2]  # positional display order of the heatmap rows
sns.set(font_scale=6.5)
cbar_ax = fig.add_axes([1.025, .3, .02, .5])

# Shared, zero-centred colour limits: +/- the largest absolute score over all
# four panels. Concatenating raveled arrays also works if panel shapes differ,
# and nanmax ignores missing scores.
panel_values = np.concatenate([np.asarray(table[comparison], dtype=float).ravel()
                               for table in (cd4_0, cd4_1, cd8_0, cd8_1)])
color_limit = np.nanmax(np.abs(panel_values))
chose_max_min = [-color_limit, color_limit]

# Column (cluster) tick labels and display order, shared by both visits of a subset.
CD4_XLABELS = np.array(['1: TN (RAhi R7mid)','2: TCM (CXCR3+ CCR4mid)','3: TEM Th2','4: TEM Th2','5: TCM','6: TCM Th2','7: TEM','8: TN (RAmid R7+)','9: TEM Th2','10: TN (RA+ R7lo)','11: TCM','12: TCM Th17','13: TEM Th1','14: Transitional TCM','15: TSCM','16: Transitional TCM','17: TCM Th2','18: TCM Th1','19: TN (RAlo R7hi)','20: TN (RAhi R7hi)','21: TEM Th17 DN','22: TN (RA+ R7hi)','23: TCM','24: TEMRA','25: TEM'])
CD4_COL_ORDER = np.array([20, 1, 22, 8, 10, 19, 15, 14, 16, 5, 11, 23, 18, 17, 6, 2, 12, 25, 7, 13, 3, 4, 9, 21, 24]) - 1  # 1-based cluster ids -> 0-based positions
CD8_XLABELS = np.array(['1: TN (RAhi R7hi CXCR3lo)','2: TEM','3: TEMRA','4: TN (RA+ R7+ CXCR3+)','5: TEM Tc17 DN','6: TN (RAhi R7+)','7: TCM (CXCR3+ CCR4+)','8: TCM Tc1','9: TEMRA','10: TSCM','11: TCM Tc1','12: TCM Tc2','13: TN (RAhi R7hi CXCR3+)','14: TEMRA','15: TN (RAhi R7hi)','16: TSCM (CXCR3+)','17: TSCM','18: TEMRA (CXCR3+)','19: TN(RAhi R7+ CXCR3lo)','20: TCM','21: TEMRA','22: TN (RA+ R7hi)','23: TN (RAmid R7+)','24: TEM','25: TCM Tc1'])
CD8_COL_ORDER = np.array([15, 6, 22, 23, 1, 19, 4, 13, 17, 10, 16, 20, 25, 8, 11, 12, 7, 2, 24, 5, 9, 21, 14, 3, 18]) - 1
for labels, order in ((CD4_XLABELS, CD4_COL_ORDER), (CD8_XLABELS, CD8_COL_ORDER)):
    assert np.array_equal(np.sort(order), np.arange(len(labels))), "column order must be a permutation of the clusters"

# axn.flat is row-major: [top-left, top-right, bottom-left, bottom-right].
panels = [
    (cd4_0, 'CD4 Enrollment', CD4_XLABELS, CD4_COL_ORDER),
    (cd8_0, 'CD8 Enrollment', CD8_XLABELS, CD8_COL_ORDER),
    (cd4_1, 'CD4 1 Month', CD4_XLABELS, CD4_COL_ORDER),
    (cd8_1, 'CD8 1 Month', CD8_XLABELS, CD8_COL_ORDER),
]

for i, (ax, (tables, title_set, get_xlabel, fix_idx)) in enumerate(zip(axn.flat, panels)):
    heatmap_data = tables[comparison]
    # Only the first panel draws the (shared) colour bar, into cbar_ax.
    axs = sns.heatmap(heatmap_data.iloc[row_cluster, fix_idx], ax=ax, annot=False, cmap="vlag",
                      linewidths=0.03, linecolor="white",
                      vmin=chose_max_min[0], vmax=chose_max_min[1],
                      cbar_ax=None if i else cbar_ax, cbar=i == 0)
    if i >= 2:
        # Bottom row carries the cluster labels for its column.
        axs.set_xticklabels(get_xlabel[fix_idx], rotation=45, ha="right", fontsize=60)
    else:
        axs.set_xticklabels([])
    col_title = heatmap_data.index
    axs.set_yticklabels(col_title[row_cluster], rotation=0, ha="right", fontsize=60)
    axs.set_title(title_set, fontsize=60)

fig.savefig("_".join(pair_names)+"_heatmap.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
#plot freq_inf roc curve (figure 6A)
# ROC curves of the infection-status classifiers, one SVG per T-cell subset.
# Colour encodes the comparison; solid = visit 0, dashed = visit 1.
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve,auc
import pickle

# Prediction dicts written by the RF export cell (trusted local files).
with open("CD4_v0_noH_dict.pkl", 'rb') as f:
    cd4_0_new_dict = pickle.load(f)
with open("CD8_v0_noH_dict.pkl", 'rb') as f:
    cd8_0_new_dict = pickle.load(f)
with open("CD4_v1_noH_dict.pkl", 'rb') as f:
    cd4_1_new_dict = pickle.load(f)
with open("CD8_v1_noH_dict.pkl", 'rb') as f:
    cd8_1_new_dict = pickle.load(f)

# (dict key prefix, comparison label, line colour)
INF_COMPARISONS = [
    ('inf01', 'Uninf vs CT Only', '#3D5A80'),
    ('inf02', 'Uninf vs Coinf', '#B6244F'),
    ('inf12', 'CT Only vs Coinf', '#17A398'),
]
INF_VISIT_LINESTYLES = ('-', '--')  # visit 0, visit 1

plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
sns.set(font_scale=2.5)
sns.set_style(style='white')

for cell_type, visit_dicts in (('CD4', (cd4_0_new_dict, cd4_1_new_dict)),
                               ('CD8', (cd8_0_new_dict, cd8_1_new_dict))):
    # Axes are created after the style is set so the 'white' style applies.
    fig1, ax = plt.subplots(figsize=[12,12])
    for visit, (results, linestyle) in enumerate(zip(visit_dicts, INF_VISIT_LINESTYLES)):
        for key, comparison, color in INF_COMPARISONS:
            fpr, tpr, _ = roc_curve(results[key+'_true'], results[key+'_pred'])
            roc_auc = auc(fpr, tpr)
            label = '%s %d-%s (AUC = %0.2f)' % (cell_type, visit, comparison, roc_auc)
            ax.plot(fpr, tpr, lw=8, alpha=1, label=label, linestyle=linestyle, color=color)
            # The legend is hidden in the exported figure, so report AUCs here.
            print(label)

    ax.plot([0,1],[0,1],linestyle = '--',lw = 2,color = 'black')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('', weight='bold')
    # Empty, frameless legend: keep the published figure free of a legend box.
    ax.legend([], frameon=False)
    fig1.savefig(cell_type+"_freq_inf_roc.svg", format="svg", dpi=300)
    plt.show()

In [None]:
#plot freq_enfu roc curve (figure 7A)
# ROC curves of the endometrial (asc) and follow-up (fol) classifiers within
# each infection group, one SVG per T-cell subset.
# Colour encodes the comparison; solid = visit 0, dashed = visit 1.
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve,auc
import pickle

# Prediction dicts written by the RF export cell (trusted local files).
with open("CD4_v0_noH_dict.pkl", 'rb') as f:
    cd4_0_new_dict = pickle.load(f)
with open("CD8_v0_noH_dict.pkl", 'rb') as f:
    cd8_0_new_dict = pickle.load(f)
with open("CD4_v1_noH_dict.pkl", 'rb') as f:
    cd4_1_new_dict = pickle.load(f)
with open("CD8_v1_noH_dict.pkl", 'rb') as f:
    cd8_1_new_dict = pickle.load(f)

# (visit, dict key prefix, comparison label, line colour), in drawing order.
# Endometrial (asc) curves are drawn for visit 0 only, as in the original figure.
ENFU_CURVES = [
    (0, 'inf1_asc01', 'CT Only-Endo+ vs Endo-', '#C44F51'),
    (0, 'inf1_fol01', 'CT Only-FU- vs FU+', '#55A868'),
    (0, 'inf2_asc01', 'Coinf-Endo+ vs Endo-', '#DD8453'),
    (0, 'inf2_fol01', 'Coinf-FU- vs FU+', '#4C72B0'),
    (1, 'inf1_fol01', 'CT Only-FU- vs FU+', '#55A868'),
    (1, 'inf2_fol01', 'Coinf-FU- vs FU+', '#4C72B0'),
]
ENFU_VISIT_LINESTYLES = {0: '-', 1: '--'}

plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
sns.set(font_scale=2.5)
sns.set_style(style='white')

for cell_type, visit_dicts in (('CD4', {0: cd4_0_new_dict, 1: cd4_1_new_dict}),
                               ('CD8', {0: cd8_0_new_dict, 1: cd8_1_new_dict})):
    # Axes are created after the style is set so the 'white' style applies.
    fig1, ax = plt.subplots(figsize=[12,12])
    for visit, key, comparison, color in ENFU_CURVES:
        results = visit_dicts[visit]
        fpr, tpr, _ = roc_curve(results[key+'_true'], results[key+'_pred'])
        roc_auc = auc(fpr, tpr)
        label = '%s %d-%s (AUC = %0.2f)' % (cell_type, visit, comparison, roc_auc)
        ax.plot(fpr, tpr, lw=8, alpha=1, label=label,
                linestyle=ENFU_VISIT_LINESTYLES[visit], color=color)
        # The legend is hidden in the exported figure, so report AUCs here.
        print(label)

    ax.plot([0,1],[0,1],linestyle = '--',lw = 2,color = 'black')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('', weight='bold')
    # Empty, frameless legend: keep the published figure free of a legend box.
    ax.legend([], frameon=False)
    fig1.savefig(cell_type+"_freq_enfu_roc.svg", format="svg", dpi=300)
    plt.show()
