## SHAP-based feature importance plots

In [None]:
import autoreload
%load_ext autoreload
%autoreload 1
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
import os
from plotnine import *
import glob
import yaml
import pickle
import utils

In [None]:
# Annotation names as provided by utils; later cells pair them positionally
# with SHAP score rows (presumably the same order — confirm in utils).
annotation_names = utils.ANNOTATION_NAMES
# Binary annotations; rows with these names are dropped before plotting.
binary_annots_to_remove = [*utils.BINARY_ANNOTATION_CATEGORIES.keys()]

### Generating aggregated raw files

  * Run `explain.snakefile` before running the script below. The Snakemake pipeline generates 15 directories (`sample_1` … `sample_15`), one for each train/test sampling used for SHAP; each directory contains SHAP scores for 6 repeats of the model. The script below aggregates these scores across repeats and samplings.

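  * If you only want to regenerate the plots in the paper from the pre-computed and pre-aggregated `.pkl` files, you can skip the aggregation and go directly to the `Feature importance analysis` section. Still run the cell below (after setting `path_to_deeprvat_analysis`): it defines `exp_dir`, which the later cells use, and it skips the aggregation when the `.pkl` files already exist.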


In [None]:
# Aggregate raw SHAP scores (per repeat, per sampling) into two pickles:
#   shap_score_per_sampling.pkl    -> {sample_dir_name: mean |SHAP| per annotation}
#   shap_relative_per_sampling.pkl -> same vectors, each scaled so its maximum is 1
# Aggregation is skipped when both pickles already exist.
importance_by_sampling = {}
rel_importance_by_sampling = {}
# your own path: directory that contains the `deeprvat-analysis` checkout.
# os.path.join keeps the path relative when this is left empty, instead of
# pointing at the filesystem root ('/deeprvat-analysis/...').
path_to_deeprvat_analysis = ''
exp_dir = os.path.join(path_to_deeprvat_analysis, 'deeprvat-analysis',
                       'feature_importance', 'shap')
score_pkl = os.path.join(exp_dir, 'shap_score_per_sampling.pkl')
relative_pkl = os.path.join(exp_dir, 'shap_relative_per_sampling.pkl')

# Both pickles are read by later cells, so regenerate if either is missing.
if not (os.path.exists(score_pkl) and os.path.exists(relative_pkl)):
    # The trailing '' makes the pattern end in a separator, so only directories
    # match; sorting makes the sampling (and pickled dict) order reproducible.
    sampling_dirs = sorted(glob.glob(os.path.join(exp_dir, 'sample*', '')))
    if not sampling_dirs:
        raise FileNotFoundError(
            f'No sample*/ directories found in {exp_dir!r}: run explain.snakefile '
            'first, or place the pre-computed .pkl files in that directory.')
    for sampling_dir in sampling_dirs:
        sample_no = os.path.basename(os.path.normpath(sampling_dir))
        per_repeat_importance = utils.collect_repeats_in_one_sampling(
            sampling_dir, annotation_names)
        agg_importance = utils.agg_over_repeats(per_repeat_importance)
        # Stack the per-key vectors returned by utils (presumably one per
        # phenotype — confirm in utils) and average them element-wise.
        agg_importance_arr = np.array(list(agg_importance.values()))
        agg_over_pheno = np.mean(agg_importance_arr, axis=0)
        importance_by_sampling[sample_no] = agg_over_pheno
        # Scale so the most important annotation in this sampling equals 1.
        rel_importance_by_sampling[sample_no] = agg_over_pheno / np.max(agg_over_pheno)

    with open(score_pkl, 'wb') as f:
        pickle.dump(importance_by_sampling, f)

    with open(relative_pkl, 'wb') as f:
        pickle.dump(rel_importance_by_sampling, f)


### Feature importance analysis.

In [None]:
# Generates Supp. Figs. 3.7 & 3.8

In [None]:
# Load the absolute per-sampling SHAP importances, average them over all
# samplings, rescale relative to the top annotation, then drop binary ones.
with open(f'{exp_dir}/shap_score_per_sampling.pkl', "rb") as handle:
    importance_by_sampling = pickle.load(handle)

# One column per sampling; rows are paired positionally with annotation_names.
per_sample_df = pd.DataFrame(importance_by_sampling)
agg_over_samples = per_sample_df.mean(axis=1)

agg_importance_df = pd.DataFrame({'importance': agg_over_samples,
                                  'annotation': annotation_names})
# The maximum is taken over ALL annotations, i.e. before binary ones are removed.
peak_importance = np.max(agg_importance_df.importance)
agg_importance_df['importance'] = agg_importance_df.importance / peak_importance

binary_mask = agg_importance_df.annotation.isin(binary_annots_to_remove)
agg_importance_df = agg_importance_df[~binary_mask].reset_index(drop=True)

# Adds the columns used by the plot below (e.g. anno_print / anno_code).
agg_importance_df = utils.add_plot_helper_columns(agg_importance_df)
agg_importance_df.head()

In [None]:
# 8-pt Helvetica used for all text in this plot and the per-sampling boxplot.
plot_font = element_text(size = 8, family = "Helvetica")

# Bar chart of the aggregated relative |SHAP| importance per quantitative
# annotation, saved as a PDF next to the SHAP results.
# Bars are ordered by anno_code (presumably added by utils.add_plot_helper_columns).
_bar_aes = aes(x='reorder(anno_print, anno_code)', y='importance', fill='annotation')
_bar_theme = theme(
    subplots_adjust={'wspace': 0.1, 'hspace': 0.25},
    axis_text_x=element_text(rotation=45, hjust=1),
    text=plot_font,
    axis_text=plot_font,
    axis_title=plot_font,
    legend_position='top',
    figure_size=(10, 8),
)
_bar_fill = scale_fill_manual(
    values=utils.QUANT_COLOR_LIST,
    breaks=list(utils.ANNOTATION_CODES.keys()),
    labels=list(utils.ANNOTATION_CATEGORIES.values()),
)

p1 = (
    ggplot(agg_importance_df, _bar_aes)
    + geom_bar(stat='identity')
    + theme_classic()
    + _bar_theme
    + _bar_fill
    + labs(x='Quantitative annotation', y='Relative |SHAP importance|')
)
ggsave(plot=p1, limitsize=False, verbose=False,
       filename=f'{exp_dir}/quantitative_importance_individual.pdf')

### Plot relative importance per sampling 

In [None]:
# Load the per-sampling relative SHAP importances (each sampling's vector is
# already scaled so that its maximum is 1), drop binary annotations, and
# reshape to long format: one row per (annotation, sampling) pair.
with open(f'{exp_dir}/shap_relative_per_sampling.pkl', "rb") as f:
    rel_importance_by_sampling = pickle.load(f)

per_sample_df = pd.DataFrame(rel_importance_by_sampling)
per_sample_df['annotation'] = annotation_names
binary_mask = per_sample_df.annotation.isin(binary_annots_to_remove)
per_sample_df = per_sample_df.loc[~binary_mask].reset_index(drop=True)

## melt to plot
# Listing value_vars in column order (rather than via a set difference, whose
# iteration order depends on the per-process string hash seed) keeps the
# melted row order reproducible across sessions.
sampling_columns = [col for col in per_sample_df.columns if col != 'annotation']
melted_per_sample_df = pd.melt(per_sample_df, id_vars=['annotation'],
                               value_vars=sampling_columns)

## columns to help plots
melted_per_sample_df = utils.add_plot_helper_columns(melted_per_sample_df)
melted_per_sample_df.head()

In [None]:
# Boxplot of the relative |SHAP| importance of each quantitative annotation
# across samplings (one melted row per sampling per annotation), saved as PDF.
# The legend is hidden; fill colours use the same scale as the bar chart.
_box_aes = aes(x='reorder(anno_print, anno_code)', y='value', fill='annotation')
_box_theme = theme(
    axis_text_x=element_text(rotation=45, hjust=1),
    figure_size=(10, 8),
    text=plot_font,
    axis_text=plot_font,
    axis_title=plot_font,
    legend_title=element_blank(),
    legend_position='none',
)
_box_fill = scale_fill_manual(
    values=utils.QUANT_COLOR_LIST,
    breaks=list(utils.ANNOTATION_CODES.keys()),
    labels=list(utils.ANNOTATION_CATEGORIES.values()),
)

p = (
    ggplot(melted_per_sample_df)
    + geom_boxplot(_box_aes)
    + theme_classic()
    + _box_theme
    + _box_fill
    + labs(x='Annotation', y='Relative |SHAP importance|')
)
ggsave(plot=p, filename=f'{exp_dir}/quantitative_importance_per_sampling.pdf',
       limitsize=False, verbose=False)

In [None]:
# the end.