In [1]:
import scanpy as sc
import numpy as np
import pandas as pd
import collections

import plotly
import seaborn as sns
import scipy.io as sci
import plotly.graph_objects as go
import chart_studio.plotly as py

  from pandas.core.index import RangeIndex


In [29]:
def _natural_key(cluster):
    """Sort key: numeric-looking cluster ids in numeric order first, then the rest lexically."""
    try:
        return (0, float(cluster), cluster)
    except ValueError:
        return (1, 0.0, cluster)


def _node_label(column, cluster):
    """Sankey node label: last '_'-separated token of the resolution column, '_', cluster id."""
    return str(column).split('_')[-1] + '_' + str(cluster)


def _expression_rgba(mean_expression):
    """Map a mean expression value to an ``rgba(...)`` link colour.

    Red is ``300 * exp`` (``|300 / exp|`` for negative values) and alpha is ``|exp|``,
    as in the original scheme, but both are clamped to the CSS-valid ranges
    [0, 255] and [0, 1] (browsers clamp anyway, so rendering is unchanged).
    A NaN mean (e.g. every value missing) becomes a fully transparent link
    instead of an invalid ``rgba(nan, ...)`` string.
    """
    if pd.isna(mean_expression):
        mean_expression = 0.0
    if mean_expression < 0:
        red, alpha = abs(300 / mean_expression), abs(mean_expression)
    else:
        red, alpha = 300 * mean_expression, mean_expression
    red = float(min(max(red, 0.0), 255.0))
    alpha = float(min(max(alpha, 0.0), 1.0))
    return 'rgba(' + str(red) + ',150,200,' + str(alpha) + ')'


def CellLayers(data_df, cell_df, gene, count_threshold=30):
    """Build Sankey link data describing how cells move between clustering resolutions.

    Every column of ``data_df`` whose first '_'-separated token is ``'integrated'``
    is treated as one clustering resolution, in column order. A link joins a
    cluster at one resolution to a cluster at the next resolution and is weighted
    by the number of cells assigned to both. The most frequent full path of
    cluster assignments across all resolutions is printed. Links on that path are
    coloured red.

    Parameters
    ----------
    data_df : pandas.DataFrame
        One row per cell, holding the ``integrated*`` cluster-assignment columns.
        Not modified.
    cell_df : pandas.DataFrame
        Must contain a column ``gene`` and be indexed by the same labels as
        ``data_df`` (``.loc`` raises KeyError for missing ids).
    gene : str
        Column of ``cell_df`` whose mean over each transition's cells colours the link.
    count_threshold : int, default 30
        Only transitions shared by *more than* this many cells are kept.

    Returns
    -------
    sankey_df : pandas.DataFrame
        One row per kept link with columns ``source_label``, ``target_label``,
        ``source``, ``target`` (indices into ``label``), ``value`` (cell count),
        ``source_full``/``target_full`` (resolution column names),
        ``source_cluster``/``target_cluster`` (cluster ids as str), ``color``
        ('red' on the stable path, else 'lightblue'), ``exp_color`` (mean of
        ``gene``) and ``rgba_color``.
    label : list of str
        All node labels, grouped by resolution with clusters in natural order.
    """
    resolution_list = [column for column in data_df.columns
                       if str(column).split('_')[0] == 'integrated']
    # String copies of the cluster ids so the caller's frame is never mutated.
    clusters = data_df[resolution_list].astype(str)

    # Node labels, in a deterministic order (the original set() order was not).
    label = [_node_label(column, cluster)
             for column in resolution_list
             for cluster in sorted(clusters[column].unique(), key=_natural_key)]
    label_index = {}
    for position, name in enumerate(label):
        label_index.setdefault(name, position)  # first occurrence, like list.index

    # Most frequently walked path of cluster assignments across all resolutions.
    path_counts = collections.Counter(tuple(row) for row in clusters.to_numpy())
    stable_path = ()
    if path_counts:
        stable_path, stable_count = path_counts.most_common(1)[0]
        print('Stable cell path', stable_path, stable_count)
    path_labels = [_node_label(column, cluster)
                   for column, cluster in zip(resolution_list, stable_path)]
    path_edges = set(zip(path_labels, path_labels[1:]))

    # Count cells for every (cluster at res i) -> (cluster at res i+1) transition.
    transition_counts = collections.Counter()
    for left, right in zip(resolution_list, resolution_list[1:]):
        transition_counts.update(
            ((left, source), (right, target))
            for source, target in zip(clusters[left], clusters[right])
        )

    rows = []
    for ((source_col, source_cluster), (target_col, target_cluster)), count in transition_counts.items():
        if count <= count_threshold:
            continue
        source_label = _node_label(source_col, source_cluster)
        target_label = _node_label(target_col, target_cluster)
        in_transition = ((clusters[source_col] == source_cluster)
                         & (clusters[target_col] == target_cluster))
        cell_ids = clusters.index[in_transition.to_numpy()]
        mean_expression = cell_df.loc[cell_ids, gene].mean()
        on_path = (source_label, target_label) in path_edges
        rows.append([
            source_label, target_label,
            label_index[source_label], label_index[target_label], count,
            source_col, source_cluster, target_col, target_cluster,
            'red' if on_path else 'lightblue',
            mean_expression, _expression_rgba(mean_expression),
        ])

    columns = ['source_label', 'target_label', 'source', 'target', 'value',
               'source_full', 'source_cluster', 'target_full', 'target_cluster',
               'color', 'exp_color', 'rgba_color']
    return pd.DataFrame(rows, columns=columns), label

In [26]:
# Load the per-cell cluster assignments (data_df) and the per-cell expression
# table (cell_df). CellLayers looks up data_df's index labels in cell_df.
DATA_DIR = '../../bruneauHeartHub-TBX5/scripts'

data_df = pd.read_csv(DATA_DIR + '/data_df.csv')
cell_df = pd.read_csv(DATA_DIR + '/cell_df.csv')

data_df = data_df.set_index('index')
# 'Unnamed: 0' is the name pandas gives a CSV's blank-header first column.
cell_df = cell_df.set_index('Unnamed: 0')


In [31]:
# Build the Sankey link table coloured by this gene's mean expression.
# The figure cells below read `sankey_df` and `label`, so this must run.
GENE = 'NR2F2'
sankey_df, label = CellLayers(data_df, cell_df, GENE)

In [115]:
# Expression-coloured Sankey: each link's colour is sankey_df['rgba_color'],
# derived from the gene's mean expression over that transition's cells.
# Node labels are intentionally left off to keep the figure uncluttered.
fig = go.Figure(data=[
    go.Sankey(
        node={
            'color': 'purple',
            'line': {'width': 0.5, 'color': "black"},
            'thickness': 2,
            'pad': 20,
        },
        link={
            'source': sankey_df['source'],
            'target': sankey_df['target'],
            'value': sankey_df['value'],
            'color': sankey_df['rgba_color'],
        },
    )
])

fig.update_layout(
    paper_bgcolor='black',
    plot_bgcolor='black',
    font={'color': 'white', 'size': 10},
    hovermode='x',
)

# Writes the figure to a standalone HTML file (fig.show() would render inline instead).
plotly.offline.plot(fig, filename='tbx5_cm_biorep_louvain_decomposition.html')

'tbx5_cm_biorep_louvain_decomposition.html'

In [85]:
# Stable-path Sankey: links along the most frequently walked cluster path are red,
# all others light blue (the `color` column of sankey_df).

def dominant_prediction(link):
    """Return the most frequent `predictions` value among the cells making one link's transition.

    Plotly needs exactly one link label per link. `data_df['predictions']` has one
    value per cell, so it cannot be passed to the Sankey directly. Cluster columns
    are compared as strings because sankey_df stores cluster ids as str.
    Returns '' when no cell has a (non-missing) prediction.
    """
    in_transition = (
        (data_df[link.source_full].astype(str) == link.source_cluster)
        & (data_df[link.target_full].astype(str) == link.target_cluster)
    )
    counts = data_df.loc[in_transition, 'predictions'].value_counts()
    return str(counts.index[0]) if len(counts) else ''


link_labels = [dominant_prediction(link) for link in sankey_df.itertuples(index=False)]

fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=20,
        thickness=2,
        line=dict(color="black", width=0.5),
        label=label,
        color='purple'
    ),
    link=dict(
        source=sankey_df['source'],
        target=sankey_df['target'],
        color=sankey_df['color'],
        value=sankey_df['value'],
        label=link_labels
    ))])
plotly.offline.plot(fig, filename='tbx5_dag.html')


'tbx5_dag.html'