## Mutagenesis analysis for binary annotations

In [None]:
import autoreload
%load_ext autoreload
%autoreload 1

In [None]:
import numpy as np
import pandas as pd
import pandas.api.types as pdtypes
import glob
import pickle
import os
from plotnine import *
import utils

### Generating the aggregated csv file

  * You first need to run `mutate.snakefile` before you can run the script below. The Snakemake pipeline generates 11 directories (e.g. `annot_4` ... `annot_14`), one per binary annotation, each containing the mutagenesis difference scores for 6 repeats of the model. The script below averages these scores across repeats.

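  * You can skip the aggregation below and jump directly to the `Feature importance analysis` section if you only want to regenerate the plots in the paper from the pre-computed `per_pheno_importance.csv` file. Note that `exp_dir` still has to be set, because the later cells read the CSV from it and write the plots to it.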


In [None]:
# Provide your own path: the directory that contains `deeprvat-analysis`.
path_to_deeprvat_analysis = ''
exp_dir = f'{path_to_deeprvat_analysis}/deeprvat-analysis/feature_importance/mutagenesis'

# Build `per_pheno_importance.csv` (rows: keys of the pickled result dicts,
# columns: binary annotations, values: score averaged over repeats).
# The aggregation only runs when the CSV does not exist yet; delete the file
# to force a rebuild.
if not os.path.exists(f'{exp_dir}/per_pheno_importance.csv'):
    per_annotation_means = []
    # glob returns paths in arbitrary order; sorting keeps the column order of
    # the CSV and the repeat numbering reproducible between runs.
    for annot_dir in sorted(glob.glob(f'{exp_dir}/annot*/')):
        # normpath drops the trailing separator so basename yields e.g. `annot_4`.
        annot_code = os.path.basename(os.path.normpath(annot_dir))
        annot_name = utils.BINARY_ANNOTATIONS[annot_code]

        repeat_scores = {}
        repeat_files = sorted(glob.glob(os.path.join(annot_dir, '*.pkl')))
        for i, repeat_f in enumerate(repeat_files):
            # NOTE(security): pickle.load executes arbitrary code; only load
            # files produced by your own run of the mutagenesis pipeline.
            with open(repeat_f, 'rb') as f:
                dict_pheno_results = pickle.load(f)
            # Key every score by its dict key so that repeats are aligned by
            # label rather than by position (dict order may differ per file).
            repeat_scores[f'repeat_{i}'] = pd.Series(
                {pheno: val[0] for pheno, val in dict_pheno_results.items()})

        df_results = pd.DataFrame(repeat_scores)
        # Row-wise mean over the repeats; a label missing from some repeat is
        # averaged over the repeats that do contain it (NaNs are skipped).
        per_annotation_means.append(df_results.mean(axis=1).rename(annot_name))

    # pd.concat rejects an empty list, so fall back to an empty frame when no
    # annotation directory was found.
    if per_annotation_means:
        binary_annots = pd.concat(per_annotation_means, axis=1)
    else:
        binary_annots = pd.DataFrame()
    binary_annots.to_csv(f'{exp_dir}/per_pheno_importance.csv')

### Feature importance analysis

In [None]:
# Load the aggregated importance table and use its first column as the index.
importance_table = pd.read_csv(f'{exp_dir}/per_pheno_importance.csv')
label_column = importance_table.columns[0]
binary_annots = importance_table.set_index(label_column)

In [None]:
# Column-wise mean of the importance table (mean over its rows), kept as a
# one-row frame whose columns are the annotations.
across_pheno_avg = binary_annots.mean(axis=0).to_frame().T

# Scale by the largest mean so that the top annotation gets the value 1.
largest_mean = across_pheno_avg.max(axis=1)
scaled_importance = (
    across_pheno_avg.div(largest_mean, axis=0)
    .T
    .rename(columns={0: 'relative_importance'})
)
raw_importance = across_pheno_avg.T.rename(columns={0: 'value_importance'})

# One row per annotation holding both the scaled and the raw mean.
annots_relative = scaled_importance.merge(
    raw_importance, left_index=True, right_index=True)
annots_relative['annotation'] = annots_relative.index
annots_relative = utils.add_plot_helper_columns(annots_relative)

In [None]:
## aggregating individual annotations to categories
# numeric_only=True: `annots_relative` carries string columns (e.g. 'annotation').
# pandas >= 2.0 raises a TypeError when asked to average them, whereas older
# versions silently dropped them; this keeps the old result on every version.
annots_category = annots_relative.groupby('anno_category').mean(numeric_only=True)
# Re-scale so that the category with the largest mean importance equals 1.
annots_category['relative_importance'] = (
    annots_category['value_importance']
    / annots_category['value_importance'].max()
)
annots_category['category'] = annots_category.index

#### Plots

In [None]:
# Shared text styling (Helvetica, size 8) for the text elements of the plots.
plot_font = element_text(family="Helvetica", size=8)

In [None]:
# Bar chart of the relative importance of each binary annotation, saved as PDF.
annotation_theme = theme(
    axis_text_x=element_text(rotation=45, hjust=1),
    figure_size=(10, 8),
    text=plot_font,
    axis_text=plot_font,
    axis_title=plot_font,
    legend_title=element_blank(),
    legend_position='top',
)
annotation_fill = scale_fill_manual(
    values=utils.BINARY_ANNO_COLOR_INDV_LIST,
    breaks=list(utils.BINARY_ANNOTATION_CODES.keys()),
    labels=list(utils.BINARY_ANNOTATION_CATEGORIES.values()),
)

p = ggplot(
    annots_relative,
    aes(x='reorder(anno_print, anno_code)',
        y='relative_importance',
        fill='annotation'),
)
p = p + geom_bar(stat='identity')
p = p + theme_classic()
p = p + annotation_theme
p = p + annotation_fill
p = p + labs(x='Binary annotation', y='Relative |absolute difference|')

ggsave(
    plot=p,
    filename=f'{exp_dir}/binary_importance_by_annotation.pdf',
    limitsize=False,
    verbose=False,
)

#### Aggregate by annotation group

In [None]:
# Bar chart of the relative importance per annotation category, saved as PDF.
category_names = ['Protein function', 'pLof', 'Splicing', 'Inframe indels']
category_theme = theme(
    axis_text_x=element_text(rotation=45, hjust=1),
    figure_size=(10, 8),
    text=plot_font,
    axis_text=plot_font,
    axis_title=plot_font,
    legend_title=element_blank(),
    legend_position='none',
)
category_fill = scale_fill_manual(
    values=utils.BINARY_ANNO_COLOR_GROUP_LIST,
    breaks=list(category_names),
    labels=list(category_names),
)

p = ggplot(
    annots_category,
    aes(x='reorder(category, anno_code)',
        y='relative_importance',
        fill='category'),
)
p = p + geom_bar(stat='identity')
p = p + theme_classic()
p = p + category_theme
p = p + category_fill
p = p + labs(x='Binary annotation groups',
             y='Relative |average absolute difference per category|')

ggsave(
    plot=p,
    filename=f'{exp_dir}/binary_importance_by_category.pdf',
    limitsize=False,
    verbose=False,
)

In [None]:
# the end. 