In [None]:
import gzip
import pickle

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Patch

In [None]:
from load_data_from_synpase import load_file, load_excel, load_table
from magine.plotting.wordcloud_tools import create_wordcloud
from magine.enrichment.enrichr import Enrichr, _valid_libs, db_types
import matplotlib.pyplot as plt
import networkx as nx
from pybeataml.load_data import AMLData

In [None]:
# Build the AMLData container imported from pybeataml.load_data.
data = AMLData()

In [None]:
# Short handle on the experimental data; later cells read its
# .phospho, .wes, .proteomics and .rna attributes.
exp_data = data.exp_data

In [None]:
# Pivot the phospho measurements into a sample_id x label table of exp_value.
phosph = exp_data.phospho.pivoter(index='sample_id', columns='label', values='exp_value')
# Count the values per label (presumably non-missing counts, as for a pandas DataFrame — confirm).
phosph.count()

In [None]:
# Preview the first rows of the pivoted phospho table.
phosph.head()

In [None]:
# Same per-label counts as above, sorted so the most sparsely measured labels are easy to spot.
phosph.count().sort_values()

In [None]:
# Heatmap of the phospho data indexed by 'label', with log conversion and
# significance annotation switched off; the trailing ';' hides the return-value repr.
exp_data.phospho.heatmap(convert_to_log=False, index='label', annotate_sig=False);

In [None]:
# Heatmap of the WES data indexed by 'identifier', drawn with 3 colours and
# with log conversion and significance annotation switched off.
exp_data.wes.heatmap(
    index='identifier', 
    convert_to_log=False, 
    num_colors=3, 
    annotate_sig=False
);

In [None]:
# Table used to map patient/sample ids to their cluster assignments.
mapping = load_file('syn26642544')
# Drop the 'Barcode.ID' column.
# NOTE(review): the later `meta.join(mapping)` joins on the index, so mapping's
# index presumably already holds the barcode ids — confirm.
del mapping['Barcode.ID']
mapping.tail(5)

In [None]:
# NOTE(review): full_meta is not referenced again in the visible cells of this notebook.
full_meta = load_excel('syn26532699')

In [None]:
# Load the sample metadata and index it by 'Barcode.ID' (modifies the loaded frame in place).
meta =  load_file('syn25807733')
meta.set_index('Barcode.ID', inplace=True)
meta

In [None]:
# Metadata columns kept for annotating samples in later plots.
important_cols = [
    'FLT3.ITD', 'InitialAMLDiagnosis',
       'PostChemotherapy'
]
# NOTE(review): to_del is defined but never used in the visible cells.
to_del = ['Plex', 'Channel', 'Loading.Mass']

# Restrict meta to the columns listed above (overwrites the full metadata table).
meta = meta.loc[:, important_cols]
meta

In [None]:
# Attach the cluster assignments to the metadata (index-on-index join), keep
# only samples that have a 'k=2' assignment, and expose the barcode index as
# a regular 'sample_id' column.
cluster_maps = (
    meta
    .join(mapping)
    .dropna(subset=['k=2'])
    .reset_index()
    .rename(columns={'Barcode.ID': 'sample_id'})
)
cluster_maps

In [None]:
# Gene-set ids per clustering; only the k=5 and k=8 entries are active,
# the remaining ids are commented out.
# NOTE(review): this list is not read by any visible cell — the two active ids
# are typed again where k5 and k8 are loaded.
list_of_gene_sets = [
    #'syn26718015',
    #'syn26718016',
    'syn26718017',
    #'syn26718018',
    #'syn26718019',
    'syn26718020',
]

In [None]:
# organize output, gather cluster and data type together
def get_genes_per_cluster(feature_array):
    """Collect the unique features of every (Cluster, data_type) group.

    Parameters
    ----------
    feature_array : pandas.DataFrame
        Table with 'Cluster', 'data_type' and 'feature' columns.

    Returns
    -------
    dict
        Maps each (cluster, data_type) pair to a list of its unique features.
        For every 'Phospho' group an extra (cluster, 'phospho_gene') entry
        holds the unique prefixes of the features, i.e. the text before the
        first '-'.
    """
    genes_by_group = {}
    grouped = feature_array.groupby(['Cluster', 'data_type'])['feature']
    for group_key, features in grouped:
        genes_by_group[group_key] = list(set(features.values))
        cluster, data_type = group_key[0], group_key[1]
        if data_type == 'Phospho':
            # Feature names look like '<gene>-<site>'; keep the gene part only.
            gene_names = {site.split('-')[0] for site in features.values}
            genes_by_group[(cluster, 'phospho_gene')] = list(gene_names)
    return genes_by_group

In [None]:
# Load the feature table of the k=5 clustering and group its features
# per (cluster, data type).
k_equal_5 = 'syn26718017'
k5 = load_file(k_equal_5)
k5_clusters = get_genes_per_cluster(k5)

# Same for the k=8 clustering.
k_equal_8 = 'syn26718020'
k8 = load_file(k_equal_8)
k8_clusters = get_genes_per_cluster(k8)

In [None]:
# Which data types contribute features to each k=5 cluster.
k5.groupby('Cluster')['data_type'].unique()

In [None]:
def _annotation_colors(labels, palette_name, name, index):
    """Translate categorical sample labels into colours for an annotation bar.

    Returns a ``(lut, colors)`` pair: ``lut`` maps each unique label (sorted)
    to one colour of the seaborn palette ``palette_name``; ``colors`` is a
    Series named ``name``, indexed by ``index``, with the colour of each sample.
    """
    uniques = sorted(labels.unique())
    # One palette entry per label actually present.
    lut = dict(zip(uniques, sns.color_palette(palette_name, len(uniques))))
    colors = pd.Series(labels, index=index, name=name).map(lut)
    return lut, colors


def _add_legend(lut, title, anchor):
    """Draw a patch legend for ``lut`` anchored in figure coordinates."""
    handles = [Patch(facecolor=color) for color in lut.values()]
    return plt.legend(
        handles, list(lut), title=title,
        bbox_to_anchor=anchor, bbox_transform=plt.gcf().transFigure,
        loc='upper right',
    )


def view_cluster(data, meta_genes, subset_index='identifier', cluster_col='k=5'):
    """Clustermap of the selected features with per-sample annotation bars.

    Relies on the notebook-level ``cluster_maps`` table, which must provide a
    'sample_id' column plus the ``cluster_col``, 'FLT3.ITD' and
    'PostChemotherapy' columns.

    Parameters
    ----------
    data : object exposing ``subset(...).pivoter(...)``
        One of the ``exp_data`` modalities (e.g. ``exp_data.proteomics``).
    meta_genes : list
        Features to keep; matched against the ``subset_index`` column.
    subset_index : str
        Column of ``data`` used to select ``meta_genes``.
    cluster_col : str
        Column of ``cluster_maps`` used for the 'Cluster' colour bar.
        Defaults to 'k=5', the previously hard-coded column.

    Returns
    -------
    None
        The figure is drawn as a side effect.
    """
    # Samples as rows, features ('label') as columns; missing values become 0.
    wide = data.subset(
        meta_genes,
        index=subset_index,
    ).pivoter(
        index='label',
        values='exp_value'
    ).T.fillna(0)
    # Remember the feature columns before metadata columns are merged in.
    plot_cols = wide.columns.values

    # Inner merge: samples without an entry in cluster_maps are dropped.
    annotated = (
        wide
        .reset_index()
        .merge(cluster_maps, on='sample_id')
        .set_index('sample_id')
    )
    sample_ids = annotated.index.values

    cluster_lut, cluster_colors = _annotation_colors(
        annotated[cluster_col], "Set2", 'Cluster', sample_ids)
    flt3_lut, flt3_colors = _annotation_colors(
        annotated['FLT3.ITD'], "binary", 'FLT3.ITD', sample_ids)
    chemo_lut, chemo_colors = _annotation_colors(
        annotated['PostChemotherapy'], "Dark2", 'PostChemotherapy', sample_ids)

    network_node_colors = pd.concat(
        [cluster_colors, flt3_colors, chemo_colors],
        axis=1
    )

    figsize = (12, 18)
    sns.clustermap(
        annotated[plot_cols].T,
        col_colors=network_node_colors,
        cmap=sns.color_palette("coolwarm", 11),
        figsize=figsize,
        yticklabels=True
    )

    # plt.legend keeps only the most recent legend on the current axes, so the
    # first two are re-attached by hand after the third has been created.
    leg1 = _add_legend(cluster_lut, 'Cluster', (1, .9))
    leg2 = _add_legend(flt3_lut, 'FLT3.ITD', (1, .8))
    _add_legend(chemo_lut, 'PostChemo', (1.1, .8))

    plt.gca().add_artist(leg1)
    plt.gca().add_artist(leg2)
    
# Proteomics view of the 'Global' features of k=5 cluster 1 (default subset_index='identifier').
view_cluster(exp_data.proteomics, k5_clusters[(1, 'Global')])

In [None]:
# Phospho view of k=5 cluster 1; the features are selected via the 'label' column here.
view_cluster(exp_data.phospho, k5_clusters[(1, 'Phospho')], subset_index='label')

In [None]:
# k=5 cluster 1 across all four modalities. The phospho call passes the
# gene-level 'phospho_gene' list, so it uses the default 'identifier' column.
view_cluster(exp_data.proteomics, k5_clusters[(1, 'Global')])  
view_cluster(exp_data.phospho, k5_clusters[(1, 'phospho_gene')])
view_cluster(exp_data.wes, k5_clusters[(1, 'WES')])
view_cluster(exp_data.rna, k5_clusters[(1, 'RNA')])

In [None]:
# RNA features of k=5 cluster 3.
view_cluster(exp_data.rna, k5_clusters[(3, 'RNA')])

In [None]:
# WES features of k=5 clusters 2 and 1 combined (plain list concatenation, duplicates not removed).
view_cluster(exp_data.wes, k5_clusters[(2, 'WES')]+k5_clusters[(1, 'WES')])

In [None]:
# Proteomics heatmap restricted to the 'Global' features of k=8 cluster 1
# (matched on 'identifier', displayed by 'label'); only columns are clustered.
exp_data.proteomics.heatmap(
    k8_clusters[(1, 'Global')],
    index='label', 
    subset_index='identifier', 
    convert_to_log=False, 
#     num_colors=3, 
    annotate_sig=False,
    cluster_row=False,
    cluster_col=True,
);

In [None]:
# Number of feature entries per (cluster, data type) in the k=5 table.
k5.groupby(['Cluster', 'data_type']).count()['feature']

In [None]:
# Number of feature entries per cluster in the k=8 table.
k8.groupby('Cluster').count()['feature']

In [None]:
# Number of feature entries per (cluster, data type) in the k=8 table.
k8.groupby(['Cluster', 'data_type']).count()['feature']

In [None]:
# Which data types contribute features to each k=8 cluster.
k8.groupby('Cluster')['data_type'].unique()

In [None]:
# Enrichr engine: pass it a gene list (or a list of gene lists) and it returns the enrichment results.
# The results come back as a MAGINE enrichment_result, a data class with its own helper methods.
e = Enrichr()

In [None]:
# Build one sample name per (cluster, data_type) key, formatted as
# "<cluster>_<data_type>", and keep the gene lists in the same order.
k5_sample_names = [f'{cluster}_{d_type}' for cluster, d_type in k5_clusters]
k5_samples = list(k5_clusters.values())

# Same preparation for the k=8 clustering.
k8_sample_names = [f'{cluster}_{d_type}' for cluster, d_type in k8_clusters]
k8_samples = list(k8_clusters.values())

In [None]:
# Run Enrichr on every k=5 gene list against the Reactome_2016 library.
k5_enrichment = e.run_samples(
    k5_samples, 
    k5_sample_names, 
    gene_set_lib='Reactome_2016'
)
# Clean up names: keep only the text before the first underscore of each term name.
k5_enrichment.term_name = k5_enrichment.term_name.str.split('_').str.get(0)

In [None]:
# Distribution of n_genes per enrichment row, before filtering.
k5_enrichment.n_genes.hist();

In [None]:
# Keep only rows with more than 5 genes (overwrites k5_enrichment).
k5_enrichment = k5_enrichment.loc[k5_enrichment.n_genes>5]

In [None]:
# Same histogram after the n_genes > 5 filter.
k5_enrichment.n_genes.hist();

In [None]:
# The 20 rows of the `.sig` view with the fewest genes
# (`.sig` presumably selects the significant terms — confirm).
k5_enrichment.sig.sort_values('n_genes',ascending=True).head(20)

In [None]:
# Preview the result of removing redundant terms, ranked by combined_score.
# NOTE(review): called without inplace=True, so k5_enrichment itself is
# presumably left unchanged — confirm against the MAGINE API.
k5_enrichment.remove_redundant(
    level='dataframe', 
    sort_by='combined_score'
)

In [None]:
# Number of term entries in the `.sig` view per sample (cluster_datatype).
k5_enrichment.sig.groupby('sample_id').count()['term_name']

In [None]:
# Heatmap of the k=5 enrichment after removing redundant terms: rows are
# clustered, columns keep their original order.
k5_enrichment.remove_redundant(
    level='dataframe', 
    sort_by='combined_score'
).heatmap(
    figsize=(6, 16),
    linewidths=.01,
    y_tick_labels=True,
    cluster_col=False,
    cluster_row=True
);
# Save the current figure next to the notebook.
plt.savefig('k5_reactome_enrichment.png', bbox_inches='tight', dpi=300)

In [None]:

# Run Enrichr on every k=8 gene list against the Reactome_2016 library.
k8_enrichment = e.run_samples(
    k8_samples, 
    k8_sample_names, 
    gene_set_lib='Reactome_2016'
)
# Clean up names: keep only the text before the first underscore of each term name.
k8_enrichment.term_name = k8_enrichment.term_name.str.split('_').str.get(0)

In [None]:
# Heatmap of the k=8 enrichment after removing redundant terms: rows are
# clustered, columns keep their original order.
# NOTE(review): unlike k5, no n_genes > 5 filter is applied to k8_enrichment
# in the visible cells.
k8_enrichment.remove_redundant(
    level='dataframe', 
    sort_by='combined_score'
).heatmap(
    figsize=(6, 16),
    linewidths=.01,
    y_tick_labels=True,
    cluster_col=False,
    cluster_row=True
);
# Save the current figure next to the notebook.
plt.savefig('k8_reactome_enrichment.png', bbox_inches='tight', dpi=300)

# Network Exploration

Generate annotated set networks (ASNs): nodes are enriched terms, and edges carry information from the molecular network about how those terms are connected.

In [None]:

from magine.networks.annotated_set import create_asn
from magine.networks.visualization.notebooks import view
from magine.networks.utils import delete_disconnected_network, trim_sink_source_nodes
from magine.networks.subgraphs import Subgraph

In [None]:
# Only needs to be run once to build the background network; afterwards, load the saved file instead.

# from magine.networks.network_generator import create_background_network
# net = create_background_network('background_network')

In [None]:
# Load the pre-built background network from its gzip-compressed pickle.
# nx.read_gpickle was removed in NetworkX 3.0; gzip + pickle reads the same
# file on every NetworkX version.
# NOTE: pickle can execute arbitrary code — only load files you created yourself.
with gzip.open('background_network.p.gz', 'rb') as handle:
    net = pickle.load(handle)

In [None]:
# Sample ids present in the `.sig` view of the k=8 enrichment.
sorted(k8_enrichment.sig.sample_id.unique())

In [None]:
# Terms from the `.sig` view of the k=8 enrichment for the '1_Global' and
# '1_phospho_gene' samples.
# NOTE(review): the variable is called subset_cl_4 but the sample ids belong to cluster 1.
subset_cl_4 = k8_enrichment.sig.filter_multi(
    sample_id=['1_Global',  '1_phospho_gene']
).copy()
# NOTE(review): the disabled filter below refers to subset_global_1, which is
# only defined in a later cell.
# subset_cl_4 = subset_cl_4.loc[~subset_global_1.term_name.isin(
#     ['metabolism', 'infectious disease', 'immune system', 'disease']
# )]

# Remove redundant terms in place, ranked by combined_score.
subset_cl_4.remove_redundant(inplace=True, threshold=.5, sort_by='combined_score', level='dataframe' )
subset_cl_4

In [None]:
# Terms from the `.sig` view of the k=8 enrichment for the '1_Global' sample only.
subset_global_1 = k8_enrichment.sig.filter_multi(
    sample_id=['1_Global', ]#'1_phospho_gene', '1_WES']
).copy()
# Drop a handful of broad, generic terms by exact name.
subset_global_1 = subset_global_1.loc[~subset_global_1.term_name.isin(
    ['metabolism', 'infectious disease', 'immune system', 'disease', 'gene expression']
)]

# Remove redundant terms in place, ranked by 'rank'.
subset_global_1.remove_redundant(inplace=True, threshold=.75, sort_by='rank', level='sample' )
subset_global_1

In [None]:
# Build the annotated set network (asn) for the selected terms, together with
# the molecular network (mol_net) returned alongside it.
asn, mol_net = create_asn(
    subset_global_1,
    net, 
    remove_isolated=False,
    use_fdr=True, 
    use_threshold=False,
    min_edges=1
)
# asn = delete_disconnected_network(asn)
print(len(mol_net.nodes),len(mol_net.edges))
# Colour every term node white. Graph.node was removed in NetworkX 2.4;
# Graph.nodes[...] is the accessor already used by the rest of this cell.
for i in asn.nodes:
    asn.nodes[i]['color'] = 'white'

In [None]:


# Render the annotated set network with the 'breadthfirst' layout.
view.draw_cyjs(asn, default_node_color='black', layout='breadthfirst', spacingFactor=1)

In [None]:
# Inspect the attributes stored on one term node.
# Graph.node was removed in NetworkX 2.4; Graph.nodes[...] is the supported accessor.
asn.nodes['glucose metabolism']

In [None]:
# Render the molecular network returned by create_asn.
view.draw_cyjs(mol_net, add_parent=True)

In [None]:
# Subgraph helper built on the full background network.
sub = Subgraph(net)

In [None]:
# Expand mol_net with downstream neighbours only, restricted to the genes in
# the `.sig` view of the k=8 enrichment.
expand = sub.expand_neighbors(
    mol_net, upstream=False, downstream=True, 
    include_only=k8_enrichment.sig.all_genes_from_df(), 
    add_interconnecting_edges=True
)
# Tidy up: trim sink/source nodes, drop disconnected parts, remove self-loops.
expand = trim_sink_source_nodes(expand)
expand = delete_disconnected_network(expand)
expand.remove_edges_from(nx.selfloop_edges(expand))
# Report the node and edge counts of the final network.
print(len(expand.nodes),len(expand.edges))

In [None]:
# Render the expanded, cleaned-up network.
view.draw_cyjs(expand, add_parent=True)