In [11]:
import numpy as np
import pandas as pd 
import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import sys

def prepare_data(disease, data_dir="./data", filter_thresh=0.2):
    """Load, filter and taxonomically annotate one disease abundance table.

    Reads ``<data_dir>/abundance_<disease>.txt``, a tab-separated table whose
    first column holds row identifiers and whose remaining columns are
    samples. Rows whose identifier contains ``'k__'`` are taken as features,
    and the ``'disease'`` row is taken as the class label. Features are kept
    only if, in at least one class, at least ``filter_thresh`` of the samples
    have a non-zero value. The taxonomy of each kept feature is then parsed.

    Args:
        disease: dataset name used to build the file name (e.g. "Cirrhosis").
        data_dir: directory holding the ``abundance_*.txt`` files.
        filter_thresh: minimum per-class fraction of non-zero samples needed
            to keep a feature.

    Returns:
        A tuple ``(filter_X_data, labels, features_df)``:

        - filter_X_data: samples x features DataFrame of float64 values. Its
          row index is the sample-column labels pandas assigns when reading
          with ``header=None, index_col=0``. These start at 1, not 0, so
          align ``labels`` with it by position, not by a fresh RangeIndex.
        - labels: numpy int array of class labels, in the same row order as
          filter_X_data.
        - features_df: one row per kept feature, with taxonomy columns
          (kingdom ... species), indexed by the parsed species name.

    Raises:
        FileNotFoundError: if the abundance file does not exist.
    """
    # (column, lineage prefix, terminator, strip dots) for every rank below
    # kingdom. Only the species name keeps its dots and is not truncated.
    lower_ranks = (
        ("phylum", "p__", "|c__", True),
        ("class", "c__", "|o__", True),
        ("order", "o__", "|f__", True),
        ("family", "f__", "|g__", True),
        ("genus", "g__", "|s__", True),
        ("species", "s__", None, False),
    )

    def loadData(feature_string, label_string, label_dict):
        """Return (features x samples float64 DataFrame, int label Series)."""
        filename = os.path.join(data_dir, "abundance_" + disease + ".txt")
        if not os.path.isfile(filename):
            # Raise instead of print() + exit(): exit() is an interactive-shell
            # helper and would hide the failure from any caller.
            raise FileNotFoundError("File {} does not exist".format(filename))
        rawdata = pd.read_csv(filename, sep='\t', index_col=0, header=None)

        # Feature rows are the ones whose identifier contains feature_string.
        is_feature = rawdata.index.str.contains(feature_string, regex=False)
        X = rawdata.loc[is_feature].astype('float64')

        # Class labels: map the raw label strings to ints.
        Y = rawdata.loc[label_string].replace(label_dict).astype('int')
        return X, Y

    def filter_data(x, y, thresh):
        """Keep columns of x that are non-zero in >= thresh of some class's samples."""
        keep = np.zeros(x.shape[1], dtype=bool)
        for c in np.unique(y):
            sub_x = x[y == c]
            prevalence = (sub_x > 0).sum(axis=0) / float(len(sub_x))
            keep |= (prevalence >= thresh).to_numpy()
        # One boolean column selection (original column order preserved)
        # instead of re-concatenating a growing frame for every kept column.
        core = x.loc[:, keep]
        core.columns = core.columns.rename(None)
        return core

    def _rank_name(f, prefix, stop=None, strip_dots=True):
        """Return the taxon name following `prefix` in the '|'-joined lineage `f`."""
        name = f.split(prefix)[1]
        if stop is not None:
            name = name.split(stop)[0]
        if strip_dots:
            name = name.replace(".", "")
        # "<taxon>_unclassified..." is rewritten as "unclassified_<taxon>".
        if "_unclassified" in name:
            name = 'unclassified_' + name.split("_unclassified")[0]
        return name

    def get_feature_df(features):
        """Parse every lineage string into one row of taxonomy columns."""
        columns = {"kingdom": []}
        columns.update((rank, []) for rank, _, _, _ in lower_ranks)
        for f in features:
            # Every feature contains 'k__' (that is how rows were selected).
            # An empty kingdom name is kept as "", not replaced by "NA".
            columns["kingdom"].append(_rank_name(f, "k__", "|p__"))
            for rank, prefix, stop, strip_dots in lower_ranks:
                name = _rank_name(f, prefix, stop, strip_dots) if prefix in f else ""
                columns[rank].append(name or "NA")

        if len(columns["species"]) == 0:
            # Every feature appends a species entry, so this branch is only
            # reached for an empty feature list. The frame is then indexed by
            # genus and has no species column.
            del columns["species"]
            index_col = "genus"
        else:
            index_col = "species"
        feature_df = pd.DataFrame(data=columns)
        feature_df.index = feature_df[index_col]
        return feature_df

    feature_string = 'k__'
    label_string = 'disease'
    label_dict = {
        # Controls
        'n': 0,
        # Cirrhosis
        'cirrhosis': 1,
        # T2D and WT2D
        't2d': 1,
        # Obesity
        'leaness': 0, 'obesity': 1,
    }

    Raw_X_data, labels = loadData(feature_string, label_string, label_dict)
    Raw_X_data = Raw_X_data.transpose()  # -> samples x features
    labels = labels.values
    filter_X_data = filter_data(Raw_X_data, labels, filter_thresh)
    features = list(filter_X_data.columns.values)
    features_df = get_feature_df(features)

    return filter_X_data, labels, features_df

def plot_figure(ax, top20_plot_data_df, disease):
    """Draw paired horizontal bars of mean abundance for the top features.

    Args:
        ax: matplotlib Axes to draw on.
        top20_plot_data_df: DataFrame whose FIRST column is the signed SHAP
            value and which has "health" and "disease" columns of mean
            abundances. Its index supplies the y-axis feature names.
        disease: dataset name, used in the axes title.
    """
    y_pos = np.arange(len(top20_plot_data_df.index))
    # Healthy bar on each row position, diseased bar offset by +0.5 (drawn
    # just below it once the y-axis is inverted).
    ax.barh(y_pos, top20_plot_data_df["health"].values, height=0.5, color='green', ecolor='green', edgecolor='none')
    ax.barh(y_pos + 0.5, top20_plot_data_df["disease"].values, height=0.5, color='crimson', ecolor='crimson', edgecolor='none')

    ax.tick_params(labelsize=12, axis='x')
    ax.locator_params(nbins=5, axis='x')
    ax.set_xscale('log')
    ax.set_xlabel('Healthy (in green) and diseased (in red)\n average relative abundance [%]', size=12)

    ax.set_yticks(y_pos + 0.18)
    # Suffix each name with the sign of its SHAP value (first column). Read
    # the column positionally: `row[0]` on a string-indexed Series is a
    # deprecated positional fallback, and `.loc[name]` returns a DataFrame
    # (not a row) whenever a display name repeats.
    shap_values = top20_plot_data_df.iloc[:, 0].to_numpy()
    tick_labels = [
        name + ('(—)' if value < 0 else '(+)')
        for name, value in zip(top20_plot_data_df.index, shap_values)
    ]
    ax.set_yticklabels(tick_labels)
    ax.tick_params(labelsize=12, axis='y')
    ax.invert_yaxis()
    ax.set_title(disease + " dataset")
 
fig, ax = plt.subplots(1, 3, figsize=(22, 10))
Diseases = ["Cirrhosis", "T2D", "Obesity"]
for i, disease in enumerate(Diseases):
    filter_X_data, labels, features_df = prepare_data(disease)

    # Display names: use the genus when the species is unknown ("NA"), and
    # render "unclassified_<taxon>" as "<taxon> spp.".
    features_name = []
    for species_name, genus_name in zip(features_df.index, features_df['genus']):
        name = genus_name if species_name == 'NA' else species_name
        if name.startswith('unclassified_'):
            name = name[len('unclassified_'):] + ' spp.'
        features_name.append(name)

    Shap_df_cv = pd.read_csv('Shap_df_cv_' + disease + '_mean_10.csv', index_col=0)

    # Mean SHAP value of each feature across the numeric columns of the file.
    SHAP_mean_df = pd.DataFrame(Shap_df_cv.mean(axis=1, numeric_only=True), columns=['SHAP'])

    # Mean relative abundance per class. filter_X_data's row index is the
    # sample-column labels from read_csv(header=None), which start at 1, so
    # the labels must carry that same index. A default RangeIndex (0..N-1)
    # would make join() shift every label by one sample.
    label_dfy = pd.DataFrame(labels, columns=['group'], index=filter_X_data.index)
    tmp_data_df = filter_X_data.join(label_dfy)
    abundance_mean_df = tmp_data_df.groupby('group').mean().transpose()
    abundance_mean_df.columns = ["health", "disease"]
    abundance_mean_df.index = features_name
    SHAP_mean_df.index = features_name

    # Combine positionally: display names can repeat (e.g. several "X spp."),
    # and an index join would cross-multiply the duplicated rows.
    SHAP_Abundance_mean_df = SHAP_mean_df.assign(
        health=abundance_mean_df["health"].to_numpy(),
        disease=abundance_mean_df["disease"].to_numpy(),
    )

    # Top 20 features by absolute mean SHAP value.
    plot_data_df = SHAP_Abundance_mean_df.iloc[SHAP_Abundance_mean_df['SHAP'].abs().argsort()[::-1]]
    top20_plot_data_df = plot_data_df.head(20)
    print(disease)
    plot_figure(ax[i], top20_plot_data_df, disease)

plt.tight_layout()
# Save before show(): on interactive backends the figure may be closed or
# blank after show() returns, so a later savefig() writes an empty image.
plt.savefig("top20.png")
plt.show()

Cirrhosis
T2D
Obesity


  plt.show()
