In [None]:
# imports

import os
import numpy as np 
import pandas as pd

In [None]:

class shap_postprocessing:
    """Turn one saved SHAP value array into per-class mean |SHAP| scores and top-N rankings.

    The array stored in ``infile`` is loaded with ``np.load``, its absolute
    values are averaged over axis 1, and the result is labelled with the
    identifiers read from ``colfile`` (columns) and ``rowfile`` (rows).
    """

    def __init__(self, root_dir, infile, colfile, rowfile):
        """Store the input locations and switch the working directory.

        Args:
            root_dir: Directory to ``os.chdir`` into.
            infile: Path of the array file passed to ``np.load``.
            colfile: Text file with one column identifier per line.
            rowfile: Text file with one row identifier per line.

        Raises:
            OSError: If ``root_dir`` cannot be entered.
        """
        self.root_dir = root_dir
        self.infile = infile
        # NOTE: this changes the process-wide working directory. Any relative
        # path used afterwards (infile, colfile, rowfile, output files) is
        # resolved against root_dir, so absolute paths are the safe choice.
        os.chdir(root_dir)
        self.colfile = colfile
        self.rowfile = rowfile

    def get_shap_abs_means(self):
        """Return a DataFrame holding the mean absolute SHAP value per row/column.

        Returns:
            pandas.DataFrame whose columns come from ``colfile``, whose index
            comes from ``rowfile`` and whose index is named ``'class'``.

        Raises:
            ValueError: If the number of identifiers in ``colfile``/``rowfile``
                does not match the shape of the averaged array.
        """
        # assumes the stored array is 3-D, e.g. (classes, samples, genes), so
        # that averaging over axis 1 leaves a 2-D table - TODO confirm.
        shap_values = np.load(self.infile)  # ['arr_0']
        print("Shap value shape: ", shap_values.shape)

        shap_means = pd.DataFrame(np.abs(shap_values).mean(axis=1))

        shap_means.columns = self.get_shap_col_ids()
        shap_means.index = self.get_shap_row_ids()
        shap_means.index.name = 'class'

        return shap_means

    @staticmethod
    def _read_ids(path):
        """Read ``path`` and return its lines with trailing whitespace removed."""
        # The context manager closes the handle deterministically instead of
        # leaving it to the garbage collector.
        with open(path) as handle:
            return [line.rstrip() for line in handle]

    def get_shap_col_ids(self):
        """Return the column identifiers listed in ``self.colfile``, one per line."""
        return self._read_ids(self.colfile)

    def get_shap_row_ids(self):
        """Return the row identifiers listed in ``self.rowfile``, one per line."""
        return self._read_ids(self.rowfile)

    def get_top_n_genes(self, shap_df, outfile, n):
        """Write the ``n`` highest-valued columns of every row of ``shap_df`` to CSV.

        Args:
            shap_df: DataFrame as returned by ``get_shap_abs_means``. It is not
                modified by this call.
            outfile: Destination path handed to ``DataFrame.to_csv``.
            n: Number of entries to keep per index label.

        The CSV is in long format with the columns ``cls``, ``variable`` and
        ``value``; rows are grouped by index label in the order the labels
        appear in ``shap_df`` and sorted by descending value within a label.
        """
        print(outfile)

        # Work on a copy carrying the row label as a regular column so that
        # the caller's DataFrame does not silently gain a 'cls' column.
        labelled = shap_df.assign(cls=shap_df.index)
        meltedf = pd.melt(labelled, id_vars=['cls'])

        topN_dfs = [
            meltedf[meltedf['cls'] == cls]
            .sort_values(['cls', 'value'], ascending=False)
            .iloc[:n, :]
            for cls in shap_df.index.to_list()
        ]

        topN = pd.concat(topN_dfs, axis=0)
        topN.to_csv(outfile)
        
        

In [None]:
class ShapRunner:
    """Extract the top-N entries from every ``DNN*.npy`` SHAP file in a folder.

    For each matching file a ``shap_postprocessing`` object is created, the
    mean absolute SHAP values are computed and the top-N table is written to
    ``outdir`` as ``<input stem>_Top_<n>.txt``.
    """

    def __init__(self, shap_dir, outdir, colfile, rowfile):
        """Remember the input/output locations; no I/O happens here."""
        # folder containing Merged Shap Chunk file from all 10 models
        self.shap_dir = shap_dir
        # output folder to store top genes
        self.outdir = outdir
        # File containing genes in the same order as model input
        self.colfile = colfile
        # Test data sample labels in the same order as they used in testing
        self.rowfile = rowfile

    def get_shap_files(self):
        """Return the paths of all files in ``shap_dir`` named ``DNN*.npy``.

        The listing is sorted so that the processing order does not depend on
        the file system.

        Raises:
            OSError: If ``shap_dir`` cannot be listed.
        """
        print(self.shap_dir)

        # Join against self.shap_dir (not a notebook-level global) so the
        # method works for any instance, and require the full '.npy' suffix
        # so the output-name derivation in go() always applies.
        return [
            os.path.join(self.shap_dir, fname)
            for fname in sorted(os.listdir(self.shap_dir))
            if fname.startswith('DNN') and fname.endswith('.npy')
        ]

    def go(self, n):
        """Process every SHAP file and write one top-``n`` table per file.

        Args:
            n: Number of entries kept per row label in each output table.
        """
        suffix = '_Top_' + str(n) + '.txt'

        for infileX in self.get_shap_files():

            print('-------------*****----------------')
            # Swap only the file extension; a plain str.replace on the whole
            # path could also rewrite '.npy' inside a directory name.
            stem, _ext = os.path.splitext(os.path.basename(infileX))
            topnfile = os.path.join(self.outdir, stem + suffix)

            # NOTE: shap_postprocessing changes the working directory to
            # shap_dir, so relative paths are resolved against it from here on.
            shap_obj = shap_postprocessing(self.shap_dir, infileX, self.colfile, self.rowfile)
            shap_df = shap_obj.get_shap_abs_means()
            shap_obj.get_top_n_genes(shap_df, topnfile, n)

            print('--***------------------------***--')
    



#### In order to run this program 
Please copy all the merged SHAP chunk files into a single folder.
You will need the following four inputs to run it:
1. Path to the folder containing the merged SHAP chunk files from all 10 models (shap_dir)
2. Output folder in which to store the top genes (outdir)
3. File containing the genes in the same order as the model input (shap_genes)
4. File containing the test data sample labels in the same order as they were used in testing (shap_labels)

In [None]:
# Run configuration - replace the placeholder paths before executing this cell.
shap_dir = '/path/to/shap_dir'        # folder with the merged SHAP chunk files
outdir = '/path/to/outputdir/'        # folder receiving the top-gene tables
shap_genes = '/path/to/shap_genes'    # gene list, in model input order
shap_labels = '/path/to/shap/labels'  # row labels, in the order used in testing

# Build the runner with explicit keywords, then write the top 20 entries per file.
runner = ShapRunner(
    shap_dir=shap_dir,
    outdir=outdir,
    colfile=shap_genes,
    rowfile=shap_labels,
)
runner.go(20)

