Variables to be updated/configured:

In [None]:
WES = False # False if running for the larger epilepsy-autism multiplex network, True if running for the WES multiplex network

if WES:
    FIGURES_DIR = "figures_wes" # path to directory where figures will be saved (the directory is created if it doesn't exist)
    COMS_DIR = "communities_wes" # path to directory containing information on the communities in the network
else:
    FIGURES_DIR = "figures" # path to directory where figures will be saved (the directory is created if it doesn't exist)
    COMS_DIR = "communities" # path to directory containing information on the communities in the network

GRAPH_DIR = "gexf_files" # path to directory where the .gexf files are located
INFO_DIR = "network_info" # path to directory with information on each gene/node in the multiplex network (the directory is created if it doesn't exist)
GENE_SETS_DIR = "gene_sets" # path to directory containing .csv files with gene sets

# Setup

In [None]:
# network packages
import networkx as nx
from networkx.readwrite.gexf import read_gexf

# visualization packages
import matplotlib
import matplotlib.pyplot as plt
from matplotlib_venn import venn2
import seaborn as sns

# other packages
from collections import Counter
import os
import numpy as np
import pandas as pd
from statsmodels.stats.multitest import fdrcorrection

In [None]:
# Create the output directories. exist_ok=True makes re-running the cell a
# no-op and avoids the race between checking for the directory and creating it.
os.makedirs(FIGURES_DIR, exist_ok=True)
os.makedirs(INFO_DIR, exist_ok=True)

In [None]:
# plotting settings: set the default matplotlib font size
# (the module-composition and enrichment cells later lower it to 14)
font = {'size': 18}
matplotlib.rc('font', **font)

In [None]:
# Select the input graph files and the per-gene info table for the chosen network.
if not WES:
    gene_phenotype_filename = 'gene-phenotype-1-1000-update.gexf'
    gene_ppi_filename = 'gene-ppi-700-update.gexf'
    gene_union_filename = 'gene-union.gexf'
    info_filename = 'info_df.csv'
else:
    gene_phenotype_filename = 'gene-phenotype-wes-1-500-update.gexf'
    gene_ppi_filename = "gene-ppi-wes-700-update.gexf"
    gene_union_filename = 'gene-union-wes.gexf'
    info_filename = 'info_wes_df.csv'

# Read the phenotype layer, the PPI layer and the union graph, in that order.
gene_phenotype, gene_ppi, gene_union = (
    read_gexf(os.path.join(GRAPH_DIR, graph_file))
    for graph_file in (gene_phenotype_filename, gene_ppi_filename, gene_union_filename)
)

In [None]:
# wrapper for communities
class Coms:
    """Lightweight container for a partition of a network into communities."""

    def __init__(self, communities):
        # Stored as given; the helpers in this notebook pass a list of lists
        # of node labels (one inner list per community).
        self.communities = communities
        # Initialised to None; nothing in the visible part of the notebook sets it.
        self.overlap = None
        
# get Coms class with genes from annotated networkx graph
def get_coms_from_graph(G):
    """Group the nodes of ``G`` by their ``'module'`` node attribute.

    ``G.nodes[node]['module']`` is read for every node and used as a 1-based
    module number, so module ``k`` ends up at ``communities[k - 1]``.  Module
    numbers that no node carries yield empty lists; a graph without nodes
    yields an empty partition.

    Returns a ``Coms`` whose ``communities`` is a list of lists of nodes, each
    inner list in node-iteration order.
    """
    module_of = {node: G.nodes[node]['module'] for node in G.nodes}
    # default=0 keeps an empty graph from raising in max()
    partition = [[] for _ in range(max(module_of.values(), default=0))]
    for node, mod in module_of.items():
        # append in place instead of rebuilding the list for every node
        partition[mod - 1].append(node)
    return Coms(partition)

# takes partition with IDs and converts to Coms class with genes
def partition_to_genes(partition):
    """Translate every node ID in ``partition`` through ``id_to_gene``.

    ``id_to_gene`` is a module-level lookup that is not defined in the visible
    part of this notebook; it must exist before this function is called.
    Returns a ``Coms`` with the same community structure, holding gene names.
    """
    return Coms([[id_to_gene[node_id] for node_id in com] for com in partition])

In [None]:
# Recover the module partition stored on each graph's nodes
# (PPI layer, phenotype layer, and the union graph used as the multiplex).
coms_ppi = get_coms_from_graph(gene_ppi)
coms_phenotype = get_coms_from_graph(gene_phenotype)
coms_multiplex = get_coms_from_graph(gene_union)

In [None]:
# standardize gene names
# Lookup from a symbol as it appears in the input gene lists to the symbol(s)
# used in this notebook.  A list value expands one cluster-level entry into one
# row per member gene.
# NOTE(review): a few entries map a current symbol to an older alias or an
# Ensembl ID (e.g. TCAF1 -> FAM115A, ADGRA2 -> GPR124); presumably this matches
# the node naming of the network files - confirm before "correcting" them.
_GENE_REPLACEMENTS = {
    "ND1": "MT-ND1",
    "ND4": "MT-ND4",
    "TRNR1": "GFRA1",
    "CCM1": "KRIT1",
    "C19orf61": "SMG9",
    "EIF2C4": "AGO4",
    "HOXD": ["HOXD1", "HOXD3", "HOXD4", "HOXD8", "HOXD9", "HOXD10", "HOXD11", "HOXD12", "HOXD13"],
    "ATP6": "MT-ATP6",
    "APOE4": "APOE",
    "ENSG00000173575": "CHD2",
    "SCA2": "ATXN2",
    "B3GNT1": "B4GAT1",
    "COX3": "MT-CO3",
    "ENSG00000086848": "ALG9",
    "ATP8": "MT-ATP8",
    "ND5": "MT-ND5",
    "C2orf25": "MMADHC",
    "PIG6": "PRODH",
    "ENSG00000258947": "TUBB3",
    "ADCK3": "COQ8A",
    "COX1": "MT-CO1",
    "DXS423E": "SMC1A",
    "PCDHG": ["PCDHGA1", "PCDHGA2", "PCDHGA3", "PCDHGA4", "PCDHGA5", "PCDHGA6", "PCDHGA7", "PCDHGA8", "PCDHGA9", "PCDHGA10", "PCDHGA11", "PCDHGA12",
              "PCDHGB1", "PCDHGB2", "PCDHGB3", "PCDHGB4", "PCDHGB5", "PCDHGB6", "PCDHGB7", "PCDHGC3", "PCDHGC4", "PCDHGC5"],
    "KIAA0226": "RUBCN",
    "CYTB": "MT-CYB",
    "KIAA0442": "AUTS2",
    "KAL1": "ANOS1",
    "BRP44L": "MPC1",
    "KIAA1715": "LNPK",
    "JMJD2C": "KDM4C",
    "CCDC64": "BICDL1",
    "KIAA2022": "NEXMIF",
    "INADL": "PATJ",
    "PARK2": "PRKN",
    "NDNL2": "NSMCE3",
    "BZRAP1": "TSPOAP1",
    "ERBB2IP": "ERBIN",
    "HIST1H2BJ": "H2BC11",
    "ADSS": "ADSS2",
    "C15orf43": "TERB2",
    "C16orf13": "METTL26",
    "C11orf30": "EMSY",
    "SUV420H1": "KMT5B",
    "MKL2": "MRTFB",
    "ENSG00000259159": "MFRP",
    "MARCA2": "SMARCA2",
    "C11orf82": "DDIAS",
    "CSNK2B-LY6G5B-1181": "CSNK2B",
    'TCAF1': 'FAM115A',
    'KCNMB2': 'ENSG00000275163',
    'KIAA1009': 'CEP162',
    'AGMO(alsoknownasTMEM195)': 'AGMO',
    'PPIEL': 'PPIEL',
    'GGTA1P': 'GGTA1',
    'KIAA1239': 'NWD2',
    'LINC01370': 'LINC01370',
    'PCDHA@': ['PCDHA10', 'PCDHA9', 'PCDHA5', 'PCDHA11', 'PCDHA7', 'PCDHA3', 'PCDHA8', 'PCDHA2', 'PCDHA1', 'PCDHA13', 'PCDHA4', 'PCDHA6', 'PCDHA12'],
    'MsrA': 'MSRA',
    'DGCR6': 'ENSG00000183628',
    'ZNF259': 'ZPR1',
    'ADGRA2': 'GPR124',
    'KIAA1430': 'CFAP97',
    'RNASE4': 'ENSG00000258818',
    'C14orf166B': 'LRRC74A',
    "RP11-1055B8.7": "BAHCC1",
    "ENSG00000272414": "FAM47E-STBD1",
    "C5orf20": "DCANP1",
    "SOGA2": "MTCL1",
    "FAM194A": "ERICH6",
}


def update_genes(df, gene_col):
    """Return a copy of ``df`` with the symbols in ``gene_col`` standardized.

    Each row whose gene appears in the replacement table is rewritten with the
    replacement symbol; when the replacement is a list, the row is emitted once
    per listed gene.  Every replacement is reported with ``print``.  Rows keep
    their original index label (so expanded rows share one label) and their
    original column order.  ``df`` itself is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding one gene per row.
    gene_col : str
        Name of the column in ``df`` that holds the gene symbol.

    Returns
    -------
    pandas.DataFrame
        New frame with the same columns as ``df`` (also when ``df`` is empty).
    """
    # Collect rows in a list and build the frame once: DataFrame.append was
    # removed in pandas 2.0 and was quadratic when called per row.
    rows = []
    for _, row in df.iterrows():
        gene = row[gene_col]
        if gene not in _GENE_REPLACEMENTS:
            rows.append(row)
            continue
        replacement = _GENE_REPLACEMENTS[gene]
        targets = [replacement] if isinstance(replacement, str) else replacement
        for target in targets:
            print("Replaced", gene, "with", target)
            renamed = row.copy()  # one independent row per emitted gene
            renamed[gene_col] = target
            rows.append(renamed)

    return pd.DataFrame(rows, columns=df.columns)

In [None]:
# color palettes for figures
# Autism palette: four shades picked from the reversed "OrRd" palette; the
# second shade is the channel-wise average of palette entries 1 and 2.
col_pal_a = sns.color_palette("OrRd_r")
col_pal_a = [col_pal_a[0], [sum(x)/2 for x in zip(col_pal_a[1], col_pal_a[2])], col_pal_a[3], col_pal_a[5]]
sns.palplot(col_pal_a)
# Epilepsy palette: entries 0, 2, 3 and 5 of the reversed "GnBu" palette.
col_pal_e = sns.color_palette("GnBu_r")
col_pal_e = [col_pal_e[0], col_pal_e[2], col_pal_e[3], col_pal_e[5]]
sns.palplot(col_pal_e)

In [None]:
# Epilepsy gene list: load, standardize the symbols in the 'gene' column,
# then keep the unique symbols as a set.
epilepsy_genes_df = pd.read_csv(os.path.join(GENE_SETS_DIR, "epilepsy_genes_wang_2017_formatted.csv"))
epilepsy_genes_df = update_genes(epilepsy_genes_df, 'gene')
epilepsy_genes = set(epilepsy_genes_df['gene'])

In [None]:
# Autism (SFARI Gene export) list: load, standardize the symbols in the
# 'gene-symbol' column, then keep the unique symbols as a set.
autism_genes_df = pd.read_csv(os.path.join(GENE_SETS_DIR, "SFARI-Gene_genes_01-03-2020release_01-05-2020export.csv"))
autism_genes_df = update_genes(autism_genes_df, 'gene-symbol')
autism_genes = set(autism_genes_df['gene-symbol'])

In [None]:
# WES autism genes: only the first 102 data rows of the file are read.
autism_wes_df = pd.read_csv(os.path.join(GENE_SETS_DIR, "WES_autism_Satterstrom_2020.csv"), nrows=102)
autism_wes_df = update_genes(autism_wes_df, "gene")
autism_wes_genes = set(autism_wes_df['gene'])

# WES epilepsy genes: the first two lines of the file are skipped, then 200 data rows are read.
epilepsy_wes_df = pd.read_csv(os.path.join(GENE_SETS_DIR, "WES_EPI_gene_burden_AC_1_Epi25_Collaborative_2019.csv"), nrows=200, skiprows=2)
epilepsy_wes_df = update_genes(epilepsy_wes_df, 'Gene')
epilepsy_wes_genes = set(epilepsy_wes_df['Gene'])

In [None]:
# venn diagram of epilepsy- and autism-associated genes
c = venn2([epilepsy_genes, autism_genes], set_labels = ('Epilepsy-associated Genes', 'Autism-associated Genes'))

# Region ids: '10' = epilepsy only, '01' = autism only, '11' = shared.
for region, color in (('10', col_pal_e[0]), ('01', col_pal_a[0]), ('11', "purple")):
    patch = c.get_patch_by_id(region)
    if patch is not None:  # matplotlib-venn returns None for an empty region
        patch.set_color(color)

for text in c.set_labels:
    if text is not None:
        text.set_fontsize(12)
for text in c.subset_labels:
    if text is not None:  # empty regions have no count label
        text.set_fontsize(12)

plt.tight_layout()
# os.path.join for consistency with the other savefig calls in this notebook
plt.savefig(os.path.join(FIGURES_DIR, "venn_diagram.png"), dpi=600)

# Calculate network statistics and save to file

In [None]:
# get degree and betweenness centrality of genes in network
def _module_membership(coms, column):
    """Return a (gene, <column>) frame giving each gene's 1-based module number in ``coms``."""
    pairs = [(g, mod_num) for mod_num, com in enumerate(coms, 1) for g in com]
    # passing columns= keeps this working when the partition is empty
    return pd.DataFrame(pairs, columns=['gene', column])


def get_network_stats(gene_ppi, gene_phenotype, coms_multiplex, coms_ppi, coms_phenotype, groups, group_names):
    """Build one row of network statistics and annotations per PPI node.

    Parameters
    ----------
    gene_ppi, gene_phenotype : networkx graphs
        The two network layers.  Rows are the nodes of ``gene_ppi``, sorted by
        name; phenotype values are looked up by node name, and a node missing
        from ``gene_phenotype`` ends up with 0 (via the final ``fillna``).
    coms_multiplex, coms_ppi, coms_phenotype : list of lists of node names
        Partitions; the position of a community (1-based) becomes the module
        number in ``module_multiplex`` / ``module_ppi`` / ``module_phenotype``.
    groups, group_names : parallel sequences
        Gene collections and the column name to flag each with (1 if the gene
        is in the group, otherwise 0).

    Returns
    -------
    pandas.DataFrame
        Columns: ``gene``, ``ppi_degree``, ``phenotype_degree``,
        ``ppi_betweenness``, ``phenotype_betweenness``, the three module
        columns, one column per entry of ``group_names``, and
        ``annotated_type``.  Missing values are filled with 0.

    Notes
    -----
    ``annotated_type`` is built from the module-level sets ``e1``-``e4``,
    ``a1``-``a3`` and ``a_s``, which must be defined before this is called.
    """
    node_df = pd.DataFrame({"gene": list(gene_ppi.nodes)})
    node_df = node_df.sort_values(by='gene').reset_index(drop=True)

    print('Calculating degree')
    # Look values up by gene name rather than relying on both graphs having
    # identical node sets in identical sorted order.
    node_df["ppi_degree"] = node_df["gene"].map(dict(gene_ppi.degree(gene_ppi.nodes)))
    node_df["phenotype_degree"] = node_df["gene"].map(dict(gene_phenotype.degree(gene_phenotype.nodes)))

    print('Calculating BC')
    node_df['ppi_betweenness'] = node_df["gene"].map(
        nx.algorithms.centrality.betweenness_centrality(gene_ppi))
    node_df['phenotype_betweenness'] = node_df["gene"].map(
        nx.algorithms.centrality.betweenness_centrality(gene_phenotype))

    print('Annotating modules and groups')
    for coms, column in ((coms_multiplex, 'module_multiplex'),
                         (coms_ppi, 'module_ppi'),
                         (coms_phenotype, 'module_phenotype')):
        node_df = node_df.merge(_module_membership(coms, column), on="gene", how="left")

    # one 0/1 indicator column per gene group
    for genes, name in zip(groups, group_names):
        temp_df = pd.DataFrame(list(genes), columns=['gene'])
        temp_df[name] = 1
        node_df = node_df.merge(temp_df, on='gene', how='left')

    # Human-readable list of every evidence tier a gene belongs to,
    # e.g. "Epilepsy 1, Autism S".  Reads the module-level tier sets.
    e_groups = [e1, e2, e3, e4]
    a_groups = [a1, a2, a3, a_s]
    tiers = [("Epilepsy " + str(idx), group) for idx, group in enumerate(e_groups, 1)]
    tiers += [("Autism S" if idx == 4 else "Autism " + str(idx), group)
              for idx, group in enumerate(a_groups, 1)]
    # Iterate the frame's own gene column so labels stay aligned with the rows.
    node_df['annotated_type'] = [
        ", ".join(label for label, group in tiers if n in group)
        for n in node_df['gene']
    ]

    node_df = node_df.fillna(0)
    return node_df
    

In [None]:
# Epilepsy genes split by the value (1-4) of the 'score' column.
e1, e2, e3, e4 = (
    set(epilepsy_genes_df.loc[epilepsy_genes_df['score'] == tier, 'gene'])
    for tier in (1, 2, 3, 4)
)

# Autism genes split by 'gene-score' (1-3), plus those flagged in 'syndromic'.
a1, a2, a3 = (
    set(autism_genes_df.loc[autism_genes_df['gene-score'] == tier, 'gene-symbol'])
    for tier in (1, 2, 3)
)
a_s = set(autism_genes_df.loc[autism_genes_df['syndromic'] == 1, 'gene-symbol'])

# Overlaps: tier-1 lists only, the full lists, and the WES lists.
common_genes = e1 & a1
common_all_genes = epilepsy_genes & autism_genes
common_wes_genes = autism_wes_genes & epilepsy_wes_genes

In [None]:
# Report the size of each epilepsy tier, a blank line, then each autism tier.
for _caption, _members in (
    ("Number of Epilepsy 1 genes:", e1),
    ("Number of Epilepsy 2 genes:", e2),
    ("Number of Epilepsy 3 genes:", e3),
    ("Number of Epilepsy 4 genes:", e4),
):
    print(_caption, len(_members))
print()
for _caption, _members in (
    ("Number of Autism 1 genes:", a1),
    ("Number of Autism 2 genes:", a2),
    ("Number of Autism 3 genes:", a3),
    ("Number of Autism S genes:", a_s),
):
    print(_caption, len(_members))

In [None]:
# genes associated with only one of the two disorders
a_specific = autism_genes.difference(common_all_genes)
e_specific = epilepsy_genes.difference(common_all_genes)

# groups[i] is flagged in the output column named group_names[i]
groups = [epilepsy_wes_genes, e1, e2, e3, e4, autism_wes_genes, a1, a2, a3, a_s, common_wes_genes, common_genes, common_all_genes, a_specific, e_specific]
group_names = ['e_wes', 'e1', 'e2', 'e3', 'e4', 'a_wes', 'a1', 'a2', 'a3', 'as', 'common_wes', 'common', 'common_all', 'a_specific', 'e_specific']
info_df = get_network_stats(gene_ppi, gene_phenotype, coms_multiplex.communities, coms_ppi.communities, coms_phenotype.communities, groups, group_names)

# info_filename is chosen together with the graph file names ('info_wes_df.csv'
# or 'info_df.csv') and is the name the next section reads back, so reuse it
# here instead of repeating the WES branch.
info_df.to_csv(os.path.join(INFO_DIR, info_filename), index=False)

# Module sizes and compositions

In [None]:
# Reload the per-gene table from disk so the sections below can run without
# recomputing the network statistics.
info_df = pd.read_csv(os.path.join(INFO_DIR, info_filename))

In [None]:
# Sizes of the first `cutoff` modules of each partition, in module order.
cutoff = 30
com_ppi_sizes, com_phenotype_sizes, com_multiplex_sizes = (
    [len(com) for com in layer.communities[:cutoff]]
    for layer in (coms_ppi, coms_phenotype, coms_multiplex)
)

In [None]:
# Scatter plot of module size against module number (1-based) for the three partitions.
plt.figure(figsize=(16,8))
alpha = 0.5
x_ppi = np.arange(1,len(com_ppi_sizes)+1)
plt.scatter(x_ppi, com_ppi_sizes, label="PPI", alpha=alpha, s=60)

x_phenotype = np.arange(1,len(com_phenotype_sizes)+1)
plt.scatter(x_phenotype, com_phenotype_sizes, label="Phenotype", alpha=alpha, s=60)

x_multiplex = np.arange(1,len(com_multiplex_sizes)+1)
plt.scatter(x_multiplex, com_multiplex_sizes, label="Multiplex", alpha=alpha, s=60)

plt.legend()
plt.xlabel('Module number')
# x ticks follow the number of PPI modules kept after the cutoff
plt.xticks(np.arange(1,len(com_ppi_sizes)+1))
plt.ylabel('Module size')
plt.savefig(os.path.join(FIGURES_DIR, "module_sizes.png"), dpi=600)
plt.show()

In [None]:
def get_counts(df):
    """Count the rows of ``df`` by the first indicator column that equals 1.

    Columns are checked in the fixed order common, common_all, e1-e4, a1-a3,
    as, so every row is counted exactly once.  The returned list has 11
    entries: one per column in that order, then an "other" slot.  A row that
    matches no column raises ``Exception('unknown', row)`` unless the
    module-level ``WES`` flag is set, in which case it is counted as "other".
    """
    priority = ('common', 'common_all', 'e1', 'e2', 'e3', 'e4', 'a1', 'a2', 'a3', 'as')
    other_slot = len(priority)
    counts = [0] * (other_slot + 1)
    for _, record in df.iterrows():
        slot = next(
            (pos for pos, column in enumerate(priority) if record[column] == 1),
            None,
        )
        if slot is None:
            if not WES:
                raise Exception('unknown', record)
            slot = other_slot
        counts[slot] += 1
    return counts

In [None]:
# One pie chart per multiplex module showing how many of its genes fall into
# each gene group (as counted by get_counts).
top_modules = 9 if WES else 14

if WES:
    fig, ax = plt.subplots(nrows=3, ncols=3, num=top_modules, figsize=(20, 25))
else:
    fig, ax = plt.subplots(nrows=4, ncols=4, num=top_modules, figsize=(20, 25))
    # the 4x4 grid has 16 axes but only 14 modules are drawn: hide the last two
    ax[-1, -1].axis('off')
    ax[-1, -2].axis('off')

font = {'size'   : 14}
matplotlib.rc('font', **font)

# Colours, group keys and legend names are the same for every pie, so build
# them once.  The "other" category only exists for the WES network
# (get_counts raises on unclassified genes otherwise).
colors = ['purple', 'mediumorchid'] + col_pal_e + col_pal_a
labels = ['common', 'common_all', 'e1', 'e2', 'e3', 'e4', 'a1', 'a2', 'a3', 'as']
legend_names = ['Common (HC)', 'Common (All)', 'Epilepsy 1', 'Epilepsy 2', 'Epilepsy 3', 'Epilepsy 4', 'Autism 1', 'Autism 2', 'Autism 3', 'Autism S']
if WES:
    colors = colors + ['grey']
    labels = labels + ['other']
    legend_names = legend_names + ['Other']

mod_num = 1
for row in ax:
    for col in row:
        if mod_num > top_modules:
            continue
        mod_df = info_df[info_df['module_multiplex']==mod_num]

        # get_counts always returns 11 slots; keep only those that have a
        # colour so wedges and colours stay in one-to-one correspondence
        sizes = get_counts(mod_df)[:len(colors)]

        total = sum(sizes)
        # autopct receives a percentage; convert it back to the gene count
        col.pie(sizes, colors=colors, textprops={'fontsize': 12}, pctdistance=0.8, autopct=lambda p: '{:.0f}'.format(p * total / 100) if p > 0 else "")
        col.axis('equal')
        col.set_title('Module ' + str(mod_num))

        mod_num +=1

# Explicit proxy handles tie every legend entry to its colour instead of
# relying on the order in which the wedges were added to the figure.
plt.figlegend(handles=[
    matplotlib.patches.Patch(facecolor=color, label=name)
    for color, name in zip(colors, legend_names)
])
plt.savefig(os.path.join(FIGURES_DIR, "module_compositions.eps"), dpi=600)
plt.show()

# Gene enrichment in multiplex network

In [None]:
def get_enrichments_matrix(df):
    """Turn a long enrichment table into a gene-group x module matrix of -log10(FDR).

    ``df`` must provide the columns ``label``, ``module`` and ``pval``.  Rows
    whose label is one of the gene-group labels below are kept, the p-values
    are FDR-corrected separately within each module, and the result is
    returned with one row per gene group (display names as index) and one
    column per module, ordered by module number.  Columns are labelled by
    position (0, 1, ...), not by module number.
    """
    # The same label sets are used for the WES and the full network.
    # NOTE(review): the last display label is "BE_genes" while the matching
    # data label is "DE_genes" - confirm which one is intended.
    plt_labels = ["common_genes (WES)", "common_genes (HC)", "common_genes (all)", "epilepsy_WES_genes", "epilepsy_1_genes", "epilepsy_2_genes", "epilepsy_3_genes", "epilepsy_4_genes", "autism_WES_genes", "autism_1_genes", "autism_2_genes", "autism_3_genes", "autism_s_genes", "schizophrenia_genes", "BD_genes", "ID_genes", "BE_genes"]
    labels = ["common_wes_genes", "common_genes", "common_genes_all", "epilepsy_WES_genes", "e1_genes", "e2_genes", "e3_genes", "e4_genes", "autism_WES_genes", "a1_genes", "a2_genes", "a3_genes", "as_genes", "schizophrenia_genes", "BD_genes", "ID_genes", "DE_genes"]

    # .copy() so the new columns are written to an independent frame rather
    # than to a slice of df
    enrichments = df[(~df['label'].str.contains('HP:'))&(df['label'].isin(labels))].copy()

    # FDR correction per module.  Results are written back through the module
    # mask, so they stay attached to the right rows whatever the row order.
    enrichments['p_adjusted'] = np.nan
    for mod_num in enrichments['module'].dropna().unique():
        in_module = enrichments['module'] == mod_num
        rejected, fdr = fdrcorrection(enrichments.loc[in_module, 'pval'].to_numpy())
        enrichments.loc[in_module, 'p_adjusted'] = fdr
    enrichments['neg_log_pval'] = -np.log10(enrichments['p_adjusted'])

    temp = [
        list(enrichments[enrichments['label']==label].sort_values(by='module')['neg_log_pval'])
        for label in labels
    ]

    plt_labels = [label.upper().replace("_", " ") for label in plt_labels]
    enrichment_df = pd.DataFrame(temp)
    enrichment_df.index = plt_labels
    return enrichment_df


In [None]:
def neg_log_p_val_label(x):
    """Return the significance stars for a -log10 adjusted p-value.

    "***" above the 0.01 level, "**" above 0.05, "*" above 0.1, and an empty
    string otherwise.  All comparisons are strict.
    """
    for level, stars in ((0.01, "***"), (0.05, "**"), (0.1, "*")):
        if x > -np.log10(level):
            return stars
    return ""

def plot_enrichment(enrichment_df, mod_sizes, filename, vmax=None):
    """Draw, save and show a heatmap of -log10(FDR) per gene group and module.

    Parameters
    ----------
    enrichment_df : pandas.DataFrame
        Gene groups as rows, modules as columns, values are -log10(FDR).
    mod_sizes : sequence
        ``mod_sizes[k]`` is shown under the tick of column ``k``; needs at
        least as many entries as ``enrichment_df`` has columns.
    filename : str
        Path the figure is written to (dpi 600).
    vmax : number, optional
        Upper end of the colour scale.  When given, the colour bar gets a tick
        every 10 units plus a final tick at ``vmax`` labelled ``"<vmax>+"``.
    """
    font = {'size'   : 14}
    matplotlib.rc('font', **font)

    num_mods = len(enrichment_df.columns)
    plt.figure(figsize=(16,8))

    # tick label: module number (1-based) with the module size underneath
    xticklabels = [f'{i}\n({mod_sizes[i-1]})' for i in range(1, num_mods+1)]

    # Cell annotations are the significance stars.  Column-wise Series.map is
    # used because DataFrame.applymap is deprecated since pandas 2.1.
    labels_df = enrichment_df.apply(lambda column: column.map(neg_log_p_val_label))

    cmap = "Blues"
    ax = sns.heatmap(enrichment_df, annot=labels_df, fmt="", xticklabels = xticklabels, cbar_kws={'label': '-log10(FDR)'}, cmap=cmap, vmin=0, vmax=vmax)
    colorbar = ax.collections[0].colorbar

    if vmax:
        # Build ticks and labels from the same list so their lengths always
        # match, also when vmax is not a multiple of 10.
        inner_ticks = list(np.arange(0, vmax, 10))
        colorbar.set_ticks(inner_ticks + [vmax])
        colorbar.set_ticklabels([str(tick) for tick in inner_ticks] + [str(vmax) + "+"])

    plt.xlabel('Module\n(size)')
    plt.ylabel('Gene group')
    plt.tight_layout()

    plt.savefig(filename, dpi=600)

    plt.show()

In [None]:
# Number of multiplex modules shown in the heatmaps.
TOP_MODULES = 13 if WES else 14

# Enrichment table for the multiplex modules, truncated to the first TOP_MODULES columns.
coms_multiplex_enrichment_df = pd.read_csv(os.path.join(COMS_DIR, 'coms_multiplex_enrichment.csv'))
enrichment_df = get_enrichments_matrix(coms_multiplex_enrichment_df).iloc[:, :TOP_MODULES]

# Same for the "all genes" variant of the enrichment table.
coms_multiplex_enrichment_df_all = pd.read_csv(os.path.join(COMS_DIR, 'coms_multiplex_enrichment_all_genes.csv'))
enrichment_df_all = get_enrichments_matrix(coms_multiplex_enrichment_df_all).iloc[:, :TOP_MODULES]

In [None]:
# Heatmap for the multiplex modules.
# NOTE(review): the output file name spells "multplex"; left as is so existing
# references to the file keep working.
plot_enrichment(enrichment_df, com_multiplex_sizes, os.path.join(FIGURES_DIR, "enrichment_analysis_multplex.png"))

In [None]:
# Heatmap for the multiplex modules, "all genes" enrichment table.
# NOTE(review): the output file name spells "multplex"; left as is so existing
# references to the file keep working.
plot_enrichment(enrichment_df_all, com_multiplex_sizes, os.path.join(FIGURES_DIR, "enrichment_analysis_multplex_all_genes.png"))

# Gene enrichment in phenotype network layer

In [None]:
# Number of phenotype-layer modules shown in the heatmaps.
TOP_MODULES = 12 if WES else 18

# Enrichment table for the phenotype-layer modules, truncated to the first TOP_MODULES columns.
coms_phenotype_enrichment_df = pd.read_csv(os.path.join(COMS_DIR, 'coms_phenotype_enrichment.csv'))
enrichment_df = get_enrichments_matrix(coms_phenotype_enrichment_df).iloc[:, :TOP_MODULES]

# Same for the "all genes" variant of the enrichment table.
coms_phenotype_enrichment_df_all = pd.read_csv(os.path.join(COMS_DIR, 'coms_phenotype_enrichment_all_genes.csv'))
enrichment_df_all = get_enrichments_matrix(coms_phenotype_enrichment_df_all).iloc[:, :TOP_MODULES]

In [None]:
# Heatmap for the phenotype-layer modules (colour scale left unbounded).
plot_enrichment(enrichment_df, com_phenotype_sizes, os.path.join(FIGURES_DIR, "enrichment_analysis_phenotype.png"))

In [None]:
# Heatmap for the phenotype-layer modules, "all genes" table; colour scale capped at 100.
plot_enrichment(enrichment_df_all, com_phenotype_sizes, os.path.join(FIGURES_DIR, "enrichment_analysis_phenotype_all_genes.png"), vmax=100)

# Gene enrichment in PPI network layer

In [None]:
# Number of PPI-layer modules shown in the heatmaps.
TOP_MODULES = 10 if WES else 17

# Enrichment table for the PPI-layer modules, truncated to the first TOP_MODULES columns.
coms_ppi_enrichment_df = pd.read_csv(os.path.join(COMS_DIR, 'coms_ppi_enrichment.csv'))
enrichment_df = get_enrichments_matrix(coms_ppi_enrichment_df).iloc[:, :TOP_MODULES]

# Same for the "all genes" variant of the enrichment table.
coms_ppi_enrichment_df_all = pd.read_csv(os.path.join(COMS_DIR, 'coms_ppi_enrichment_all_genes.csv'))
enrichment_df_all = get_enrichments_matrix(coms_ppi_enrichment_df_all).iloc[:, :TOP_MODULES]

In [None]:
# Heatmap for the PPI-layer modules (colour scale left unbounded).
plot_enrichment(enrichment_df, com_ppi_sizes, os.path.join(FIGURES_DIR, "enrichment_analysis_ppi.png"))

In [None]:
# Heatmap for the PPI-layer modules, "all genes" table; colour scale capped at 100.
plot_enrichment(enrichment_df_all, com_ppi_sizes, os.path.join(FIGURES_DIR, "enrichment_analysis_ppi_all_genes.png"), vmax=100)