### Annotation analysis

In [1]:
import autoreload
%load_ext autoreload
%autoreload 1

In [2]:
import numpy as np
import pandas as pd
import pandas.api.types as pdtypes
import glob
import pickle
import os
from plotnine import *
import utils

#### Reading raw results

In [None]:
# The next cell is skipped automatically when per_pheno_importance.csv already exists; delete that file to rebuild it from the raw results.

In [3]:
# Path to the raw mutagenesis results used to build per_pheno_importance.csv.
# Each sub-directory is named by a binary-annotation code and holds one pickle
# per repeat. Alternative location:
# ~/rvat/multipheno_feature_importance/mutagenesis
experiments_dir = '~/genopheno/experiments/explain_mutagenesis/'


def _average_repeats(annot_dir):
    """Return the per-phenotype importance averaged over the repeats in *annot_dir*.

    Every ``*.pkl`` file in *annot_dir* is unpickled as a dict keyed by
    phenotype, and only the first element of each value (``val[0]``) is used.
    Repeats are aligned on the phenotype key rather than by position, so
    pickles whose keys come in a different order still line up.

    Returns a Series indexed by phenotype (empty if the directory has no pickles).
    """
    repeats = {}
    # Sorted so that the repeat numbering is deterministic.
    repeat_files = sorted(glob.glob(os.path.join(annot_dir, '*.pkl')))
    for i, repeat_f in enumerate(repeat_files):
        # NOTE: pickle.load can execute arbitrary code; only use trusted results.
        with open(repeat_f, 'rb') as f:
            dict_pheno_results = pickle.load(f)
        repeats[f'repeat_{i}'] = pd.Series(
            {pheno: val[0] for pheno, val in dict_pheno_results.items()})
    return pd.DataFrame(repeats).mean(axis=1)


if not os.path.exists('per_pheno_importance.csv'):
    # glob does not expand '~', so resolve the home directory explicitly.
    experiments_root = os.path.expanduser(experiments_dir)
    per_annotation = []
    # The trailing '' makes the pattern end in a separator: directories only.
    for annot_dir in sorted(glob.glob(os.path.join(experiments_root, '*', ''))):
        annot_code = os.path.basename(os.path.normpath(annot_dir))
        # Skip helper directories; test only the directory name, not the full path.
        if '__pycache__' in annot_code or 'logs' in annot_code:
            continue
        per_annotation.append(
            _average_repeats(annot_dir).rename(utils.BINARY_ANNOTATIONS[annot_code]))
    if not per_annotation:
        # Fail loudly instead of caching an empty table that the existence
        # check above would then reuse on every later run.
        raise FileNotFoundError(
            f'No annotation result directories found under {experiments_root!r}')
    # One column per annotation, outer-joined on phenotype.
    binary_annots = pd.concat(per_annotation, axis=1)
    binary_annots.to_csv('per_pheno_importance.csv')

### Feature importance plots

In [19]:
# Reload the cached per-phenotype table. Its first column holds the row labels
# written by to_csv and becomes the index.
binary_annots = pd.read_csv('per_pheno_importance.csv')
binary_annots = binary_annots.set_index(binary_annots.columns[0])

In [7]:
# Mean importance of each annotation across all phenotypes ('value_importance'),
# and the same value scaled so the largest annotation equals 1
# ('relative_importance'). One row per annotation.
mean_per_annot = binary_annots.mean(axis=0)
annots_relative = pd.DataFrame({
    'relative_importance': mean_per_annot / mean_per_annot.max(),
    'value_importance': mean_per_annot,
})
annots_relative['annotation'] = annots_relative.index
annots_relative = utils.add_plot_helper_columns(annots_relative)

In [9]:
## aggregating individual annotations to categories
# Average the numeric columns within each annotation category. numeric_only=True
# is required on pandas >= 2.0, where mean() raises TypeError on string columns
# such as 'annotation' instead of silently dropping them.
# NOTE(review): the category plot reads 'anno_code' from this table, so it must
# be numeric to survive the mean — confirm in utils.add_plot_helper_columns.
annots_category = annots_relative.groupby(['anno_category']).mean(numeric_only=True)
# Re-scale the category means so that the most important category equals 1.
annots_category['relative_importance'] = (
    annots_category['value_importance'] / annots_category['value_importance'].max())
annots_category['category'] = annots_category.index

#### Plots

In [13]:
# Shared text style for all plots: 8 pt Helvetica.
plot_font = element_text(family="Helvetica", size=8)

In [16]:
# Bar chart of each binary annotation's importance relative to the most
# important annotation, coloured by annotation, saved as a PDF.
# NOTE(review): breaks come from BINARY_ANNOTATION_CODES keys but labels from
# BINARY_ANNOTATION_CATEGORIES values — confirm they line up with each other
# and with the 'annotation' values.
p = (
    ggplot(
        annots_relative,
        aes(
            x='reorder(anno_print, anno_code)',
            y='relative_importance',
            fill='annotation',
        ),
    )
    + geom_bar(stat='identity')
    + theme_classic()
    + theme(
        axis_text_x=element_text(rotation=45, hjust=1),
        figure_size=(10, 8),
        text=plot_font,
        axis_text=plot_font,
        axis_title=plot_font,
        legend_title=element_blank(),
        legend_position='top',
    )
    + scale_fill_manual(
        values=utils.BINARY_ANNO_COLOR_INDV_LIST,
        breaks=list(utils.BINARY_ANNOTATION_CODES.keys()),
        labels=list(utils.BINARY_ANNOTATION_CATEGORIES.values()),
    )
    + labs(x='Binary annotation', y='Relative |absolute difference|')
)
ggsave(
    plot=p,
    filename='binary_importance_by_annotation.pdf',
    limitsize=False,
    verbose=False,
)

#### Aggregate by annotation group

In [18]:
# Bar chart of each annotation category's importance relative to the most
# important category, saved as a PDF. The legend is hidden because the x-axis
# already names each category.
p = (
    ggplot(
        annots_category,
        aes(
            x='reorder(category, anno_code)',
            y='relative_importance',
            fill='category',
        ),
    )
    + geom_bar(stat='identity')
    + theme_classic()
    + theme(
        axis_text_x=element_text(rotation=45, hjust=1),
        figure_size=(10, 8),
        text=plot_font,
        axis_text=plot_font,
        axis_title=plot_font,
        legend_title=element_blank(),
        legend_position='none',
    )
    + scale_fill_manual(
        values=utils.BINARY_ANNO_COLOR_GROUP_LIST,
        breaks=['Protein function', 'pLof', 'Splicing', 'Inframe indels'],
        labels=['Protein function', 'pLof', 'Splicing', 'Inframe indels'],
    )
    + labs(
        x='Binary annotation groups',
        y='Relative |average absolute difference per category|',
    )
)
ggsave(
    plot=p,
    filename='binary_importance_by_category.pdf',
    limitsize=False,
    verbose=False,
)