In [368]:
import scanpy as sc
import numpy as np
import pandas as pd
import collections
from itertools import islice

import plotly
import seaborn as sns
import scipy.io as sci
import plotly.graph_objects as go
import chart_studio.plotly as py
sc.settings.set_figure_params(dpi=80, facecolor='white')


In [2]:
def create_expression_df(adata):
    return pd.DataFrame(adata.X.toarray(), index=adata.obs.index.tolist(), columns=adata.var.index.tolist())

In [360]:
class CellLayers(object):
    
    def __init__(self, exp_df, meta_df, genes=None, edge_cutoff=30, flow_color = 'lightblue'):
        
        self.exp_df = exp_df
        self.meta_df = meta_df
        
        # Update to allow users to pass in a sankey_dict
        self.sankey_dict = {'data':None, 'label':None, 'dag_dict': None, 'largest_dag': {},
                            'flow_by_flow': None, 'edge_cutoff':edge_cutoff, 'flow_color' : flow_color, 'genes' : genes,
                            'marker_exp_dict' : {},
                            'resolutions':[data for data in list(self.meta_df) if data.split('.')[0] == 'res']}
    
    def create_node_labels(self):
        '''
        Create nodes, which are labeled as the louvain resolution and underscored by community assignment. 
        For example, louvain resolution 0.1 and community 4 will show as res.0.1_4.
        '''
        self.sankey_dict['label'] = [res for sublist in [list(map(( lambda x: res + '_' + str(x)), list(set(self.meta_df[res])))) \
         for res in self.sankey_dict['resolutions']] for res in sublist]
    
    def collect_dag_data(self):
        '''
        Collect directed acyclic graph data across the louvain decomposition and find the largest dag.
        '''
        
        # Create a dag dictionary where keys are tuples containing the dag and values are the number of cells that walk the dag
        self.sankey_dict['dag_dict'] = dict(collections.Counter(tuple(dag) \
                                                            for dag in self.meta_df[self.sankey_dict['resolutions']].values))
        
        # Update sankey dict keys to include the resolution parameter to the community assigment
        self.sankey_dict['dag_dict'] = {tuple([item[1] + '_' + str(item[0]) for item in zip(list(k), self.sankey_dict['resolutions'])]): v for k,v in self.sankey_dict['dag_dict'].items()}

        # Future direction would be to add this for identifying stable cell populations
        # Add largest dag to sankey dict
        self.sankey_dict['largest_dag'][max(self.sankey_dict['dag_dict'], key=self.sankey_dict['dag_dict'].get)] =\
        self.sankey_dict['dag_dict'][max(self.sankey_dict['dag_dict'], key=self.sankey_dict['dag_dict'].get)]
        
        # Trace the largest dag's decomposition 
        self.sankey_dict['largest_dag_flow'] = list(self.follow_the_flow(list(self.sankey_dict['largest_dag'].keys())[0], n=2))
        
        # Build sankey flow network
        self.count_flow_by_flow()
    
    def follow_the_flow(self, seq, n=2):
        '''
        Return a sliding window (of width n) over data from the iterable
        '''
        it = iter(seq)
        result = tuple(islice(it, n))
        if len(result) == n:
            yield result
        for elem in it:
            result = result[1:] + (elem,)
            yield result
    
    def count_flow_by_flow(self):
        '''
        Count the number of cells flowing from resolution (n) community (k) to resolution (n+1) community (k).
        ** This can be optimized to by using the data generated from create_node_labels**
        '''
        self.sankey_dict['flow_by_flow'] = dict(collections.Counter(tuple([self.sankey_dict['resolutions'][res]+'_'+str(dag[0]), self.sankey_dict['resolutions'][res+1] +'_'+str(dag[1])])\
              for res in range(0, len(self.sankey_dict['resolutions'])-1) \
              for dag in self.meta_df[[self.sankey_dict['resolutions'][res], self.sankey_dict['resolutions'][res+1]]].values))
        
        if self.sankey_dict['edge_cutoff'] > 0:
            self.trim_the_flow()
    
    def trim_the_flow(self):
        '''
        Remove edges based on user defined cutoff.
        '''
        self.sankey_dict['flow_by_flow']={i:self.sankey_dict['flow_by_flow'][i] for i in self.sankey_dict['flow_by_flow'] if self.sankey_dict['flow_by_flow'][i]>30}
    
    
    def add_meta_data_to_sankey_df(self, sankey_label, meta_to_add):
        self.sankey_dict['data'][meta_to_add[0]] = [x.split('_')[0] for x in self.sankey_dict['data'][sankey_label].tolist()]
        self.sankey_dict['data'][meta_to_add[1]] = [x.split('_')[1] for x in self.sankey_dict['data'][sankey_label].tolist()]
    
    def create_sankey_df(self):
        self.sankey_dict['data'] = pd.DataFrame([[k[0], k[1], self.sankey_dict['label'].index(k[0]), self.sankey_dict['label'].index(k[1]), v] for k,v in self.sankey_dict['flow_by_flow'].items()],\
                                 columns=['source_label', 'target_label', 'source', 'target', 'value'])
        self.add_meta_data_to_sankey_df('source_label', ['source_res', 'source_cluster'])
        self.add_meta_data_to_sankey_df('target_label', ['target_res', 'target_cluster'])
        self.color_my_dag()
    
    
    def color_my_dag(self):
        '''
        Color largest dag red and all other dags light blue.
        ** Update to improve iteration speed! **
        '''
        self.sankey_dict['data']['color'] = self.sankey_dict['flow_color']
#         print(list(self.sankey_dict['flow_by_flow'].keys()))
        for nodes in sankey_dict['largest_dag_flow']:
#             print(nodes)
            for index, row in self.sankey_dict['data'][['source_label', 'target_label', 'color']].iterrows():
                if row['source_label'] == nodes[0] and row['target_label'] == nodes[1]:
                    self.sankey_dict['data'].at[index, 'color'] = 'red'
                    
        if self.sankey_dict['genes']:
            self.compute_avg_expression()
    
    def compute_avg_expression(self):
        for marker_genes in self.sankey_dict['genes']:
            marker_list = []
            for cluster_data in self.sankey_dict['data'][['source_res', 'source_cluster', 'target_res', 'target_cluster']].values:
                cell_ids = self.meta_df[(self.meta_df[cluster_data[0]]==cluster_data[1]) & (self.meta_df[cluster_data[2]]==cluster_data[3])].index.tolist()
                marker = self.exp_df.loc[cell_ids][marker_genes].mean()
                marker_list.append(marker)
                self.sankey_dict['marker_exp_dict'][marker_genes] = marker_list
            self.paint_flow_by_gene_exp()
    
    def paint_flow_by_gene_exp(self):
        for k,v in self.sankey_dict['marker_exp_dict'].items():
            exp_list = []
            for exp in v:
                if exp < 0:
                    exp_list.append('rgba(' + str(abs(300/exp+0.1)) +', 150, 200,' + str(abs(exp)) + ')')
                else:
                    exp_list.append('rgba(' + str(300 * exp) + ',150,' + '200,' + str(exp) + ')')
            self.sankey_dict['data'][k + '_rgba'] = exp_list
                    
    def run(self):
        self.create_node_labels()
        self.collect_dag_data()
        self.create_sankey_df()
        return self.sankey_dict

In [362]:
sankey_dict = CellLayers(pd.DataFrame(adata.X.toarray(), index=adata.obs.index.tolist(), columns=adata.var.index.tolist()),\
                         adata.obs, genes=['MS4A1','GNLY','CD3E',\
                                          'CD14', 'FCER1A', 'FCGR3A',\
                                          'LYZ', 'PPBP', 'CD8A'],edge_cutoff=0).run()

In [283]:
adata.uns['modularity']['label'] = sankey_dict['resolutions']
adata.uns['modularity']['color'] = ['rgba(255, 230,'+str(100* modularity)  +','+ str(modularity)+')' for modularity in adata.uns['modularity']['modularity'].tolist()]
node_color_dict = {items[1]:items[0]for items in adata.uns['modularity'][['color','label']].values}
node_label_dict = {items[1]:items[0]for items in adata.uns['modularity'][['modularity','label']].values}

In [284]:
x = []
v = []
for items in sankey_dict['label']:
    v.append('Cluster ' +  items.split('_')[1] + '<br>'+ 'Resolution ' + '.'.join(items.split('_')[0].split('.')[1:]) + '<br>' +  'Modularity Score: {}'.format(node_label_dict[items.split('_')[0]]) )
    x.append(node_color_dict[items.split('_')[0]])    


In [367]:
for gene in ['MS4A1','GNLY','CD3E','CD14', 'FCER1A', 'FCGR3A','LYZ', 'PPBP', 'CD8A']:
    data_trace=dict(type='sankey', orientation='h', 
                    node = dict(pad = 10, thickness=10,
    #                             line = dict(color = "black", width = 0.5),
                                label=sankey_dict['label'],
                                customdata = v,
                                hovertemplate = '%{customdata}',
                                color='#F7ED32'),

                    link = dict(source = sankey_dict['data']['source'],
                                target = sankey_dict['data']['target'],
                                color = [sankey_dict['data'][gene+'_rgba'], sankey_dict['data']['color']][0],
                                value = sankey_dict['data']['value']))


    # 'GNLY','CD3E
    layout = dict(title = 'Expression of '+gene+' in the 10X PBMC Dataset',
                  updatemenus=[
                     dict(y=1, 
                          buttons=[dict(label='Light', method='relayout', args=['paper_bgcolor', 'white']),
                                   dict(label='Dark', method='relayout', args=['paper_bgcolor', 'black'])]),

                     dict(y=0.9,
                          buttons=[dict(label='Thin', method='restyle',args=['node.thickness', 8]),
                                   dict(label='Thick',method='restyle',args=['node.thickness', 15])]),

                     dict(y=0.7,
                         buttons=[dict(label='Snap',method='restyle', args=['arrangement', 'snap']),
                                  dict(label='Perpendicular', method='restyle',args=['arrangement', 'perpendicular']),
                                  dict(label='Freeform', method='restyle',args=['arrangement', 'freeform']),
                                  dict(label='Fixed', method='restyle',args=['arrangement', 'fixed'])]),

                     dict(y=0.8,
                          buttons=[dict(label='Small gap',method='restyle',args=['node.pad', 15]),
                                   dict(label='Large gap',method='restyle',args=['node.pad', 20])]),

                     dict(y=0.6,
                          buttons=[dict(label='Horizontal', method='restyle', args=['orientation', 'h']),
                                   dict(label='Vertical',method='restyle',args=['orientation', 'v'])]),

                 ])

    fig = dict(data=[data_trace], layout=layout)
    # fig.update_layout(
    #     font=dict(size = 10, color = 'purple'),
    # )
    # fig.show()
    plotly.offline.plot(fig, filename='PBMC/'+gene+'_PBMC.html')

In [None]:
fig=go.Figure(data=[dict(type='sankey', orientation='h', 
                node = dict(pad = 10, thickness=10,
                            label=sankey_dict['label'],
                            customdata = v,
                            hovertemplate = '%{customdata}',
                            color='#F7ED32'),
                
                link = dict(source = sankey_dict['data']['source'],
                            target = sankey_dict['data']['target'],
                            color = sankey_dict['data']['GNLY_rgba'],
                            value = sankey_dict['data']['value']))])

test = [dict(type='sankey', orientation='h', 
                node = dict(pad = 10, thickness=10,
#                             line = dict(color = "black", width = 0.5),
                            label=sankey_dict['label'],
                            customdata = v,
                            hovertemplate = '%{customdata}',
                            color='#F7ED32'),
                
                link = dict(source = sankey_dict['data']['source'],
                            target = sankey_dict['data']['target'],
                            color = sankey_dict['data']['CD3E_rgba'],
                            value = sankey_dict['data']['value']))]


fig.update_layout(
    updatemenus=[
                 dict(y=1, 
                      buttons=[dict(label='Light', method='relayout', args=['paper_bgcolor', 'white']),
                               dict(label='Dark', method='relayout', args=['paper_bgcolor', 'black'])]),
                 
                 dict(y=0.9,
                      buttons=[dict(label='Thin', method='restyle',args=['node.thickness', 8]),
                               dict(label='Thick',method='restyle',args=['node.thickness', 15])]),
                 
                 dict(y=0.7,
                     buttons=[dict(label='Snap',method='restyle', args=['arrangement', 'snap']),
                              dict(label='Perpendicular', method='restyle',args=['arrangement', 'perpendicular']),
                              dict(label='Freeform', method='restyle',args=['arrangement', 'freeform']),
                              dict(label='Fixed', method='restyle',args=['arrangement', 'fixed'])]),
                 
                 dict(y=0.8,
                      buttons=[dict(label='Small gap',method='restyle',args=['node.pad', 15]),
                               dict(label='Large gap',method='restyle',args=['node.pad', 20])]),
                 
                 dict(y=0.6,
                      buttons=[dict(label='Horizontal', method='restyle', args=['orientation', 'h']),
                               dict(label='Vertical',method='restyle',args=['orientation', 'v'])]),
                 
                 dict(type="buttons", direction="right", active=0, x=0.57, y=1.2, buttons=list([dict(label="GNLY",
                                                                                                     method="update",
                                                                                                     args=[{"Visible":[True, False],
                                                                                                           "annotations":test}])]))
             ])

# fig = dict(data=[gnly_data_trace], layout=layout)
# fig.update_layout(
#     font=dict(size = 10, color = 'purple'),
# )
# fig.show()
plotly.offline.plot(fig, filename='PBMC_GNLY.html')

In [307]:
fig = go.Figure(data=[go.Sankey(type='sankey', orientation='h', 
    node = dict(
        pad = 20,
        thickness = 2,
        line = dict(color = "black", width = 0.5),
#         label=label,
        color='purple'
    ),
    link = dict(
        source = sankey_dict['data']['source'],
        target = sankey_dict['data']['target'],
        color = sankey_dict['data']['color'],
        value = sankey_dict['data']['value']))])

fig.update_layout(
    hovermode = 'x',
    font=dict(size = 10, color = 'white'),
    plot_bgcolor='white',
    paper_bgcolor='white'
)


layout = dict(title = 'Cell Layers: 10X Genomics PBMC Dataset')

fig = dict(data=fig, layout=layout)

plotly.offline.plot(fig, filename='DAG_PBMC.html')

'DAG_PBMC.html'

In [623]:
data_df = adata.obs[['res.0.1',
 'res.0.2',
 'res.0.3',
 'res.0.4',
'res.0.5',
 'res.0.6',
 'res.0.7',
 'res.0.8',
 'res.0.9',
 'res.1',
 'res.1.1',
 'res.1.2',
 'res.1.3',
 'res.1.4',
 'res.1.5',
 'res.1.6',
 'res.1.7',
 'res.1.8',
 'res.1.9',
 'res.2']]

for column in list(data_df):
    data_df[column] = [column + '_' + x for x in data_df[column]]
new_df = pd.DataFrame(columns=['source', 'target', 'label'])


    
def follow_the_flow(seq, n=2):
    '''
    Return a sliding window (of width n) over data from the iterable
    '''
    it = iter(seq)
    result = tuple(islice(it, n))
    if len(result) == n:
        yield result
    for elem in it:
        result = result[1:] + (elem,)
        yield result

source_list = []
target_list = []
label_list = []
for index, row in data_df.iterrows():
    sankey_list = list(follow_the_flow([sankey_dict['label'].index(items) for items in list(row)]))
    for items in sankey_list:
        source_list.append(items[0])
        target_list.append(items[1])
        label_list.append(adata.obs.loc[index]['seurat_annotations'])
        
fig = go.Figure(data=[go.Sankey(
    node = dict(
        pad = 20,
        thickness = 2,
        line = dict(color = "black", width = 0.5),
#         label=label,
        color='purple'
    ),
    link = dict(
        source = source_list,
        target = target_list,
        label = label_list,
        value = [1] * len(label_list)))])

fig.update_layout(
    hovermode = 'x',
    font=dict(size = 10, color = 'white'),
    plot_bgcolor='white',
    paper_bgcolor='white'
)


plotly.offline.plot(fig, filename='PBMC.html')



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



### 1) Data wrangling and preperation

In [391]:
# Data wrangling for pbmc data
adata = sc.read('../Data/pbmc3k.h5ad')

# Normalize raw count data because converter does not transfer expression data
sc.pp.normalize_total(adata, target_sum=1e4) 
sc.pp.log1p(adata)

# Add Seurat's modularity scores to adata uns slot
modularity_df = pd.read_csv('../Data/pbmc3k_modularity.csv', index_col='res')
del modularity_df['Unnamed: 0']
adata.uns['modularity'] = modularity_df

In [394]:
adata.obs = adata.obs[['orig.ident',
 'nCount_RNA',
 'nFeature_RNA',
 'seurat_annotations',
 'percent.mt',
 'seurat_clusters',
 'RNA_snn_res.0.1',
 'RNA_snn_res.0.2',
 'RNA_snn_res.0.3',
 'RNA_snn_res.0.4',
 'RNA_snn_res.0.5',
 'RNA_snn_res.0.6',
 'RNA_snn_res.0.7',
 'RNA_snn_res.0.8',
 'RNA_snn_res.0.9',
 'RNA_snn_res.1',
 'RNA_snn_res.1.1',
 'RNA_snn_res.1.2',
 'RNA_snn_res.1.3',
 'RNA_snn_res.1.4',
 'RNA_snn_res.1.5',
 'RNA_snn_res.1.6',
 'RNA_snn_res.1.7',
 'RNA_snn_res.1.8',
 'RNA_snn_res.1.9',
 'RNA_snn_res.2']]

In [395]:
adata.obs.to_csv('pbmc_meta.csv', index=True, index_label='')

In [6]:
# Rename Seurat resolution names to shorten sankey node names
label_dict={data: data.split('_')[-1] \
                                  for data in list(adata.obs) if data.split('.')[0] == 'RNA_snn_res'}
adata.obs = adata.obs.rename(columns=label_dict)

adata.obs = adata.obs[['orig.ident',
 'nCount_RNA',
 'nFeature_RNA',
 'seurat_annotations',
 'percent.mt', 
 'seurat_clusters',
 'res.0.1',
 'res.0.2',
 'res.0.3',
 'res.0.4',
'res.0.5',
 'res.0.6',
 'res.0.7',
 'res.0.8',
 'res.0.9',
 'res.1',
 'res.1.1',
 'res.1.2',
 'res.1.3',
 'res.1.4',
 'res.1.5',
 'res.1.6',
 'res.1.7',
 'res.1.8',
 'res.1.9',
 'res.2']]

for res in ['res.0.1',
 'res.0.2',
 'res.0.3',
 'res.0.4',
'res.0.5',
 'res.0.6',
 'res.0.7',
 'res.0.8',
 'res.0.9',
 'res.1',
 'res.1.1',
 'res.1.2',
 'res.1.3',
 'res.1.4',
 'res.1.5',
 'res.1.6',
 'res.1.7',
 'res.1.8',
 'res.1.9',
 'res.2']:
    adata.obs[res] = [str(x) for x in adata.obs[res]]

In [373]:
exp_df = create_expression_df(adata)

In [375]:
exp_df.to_csv('pbmc_exp.csv')

In [None]:

# fig=dict(data=[dict(type='sankey', orientation='h', 
#                 node = dict(pad = 10, thickness=10,
#                             label=sankey_dict['label'],
#                             customdata = v,
#                             hovertemplate = '%{customdata}',
#                             color='#F7ED32'),
                
#                 link = dict(source = sankey_dict['data']['source'],
#                             target = sankey_dict['data']['target'],
#                             color = sankey_dict['data']['GNLY_rgba'],
#                             value = sankey_dict['data']['value']))])

# test = dict(data=[dict(type='sankey', orientation='h', 
#                 node = dict(pad = 10, thickness=10,
# #                             line = dict(color = "black", width = 0.5),
#                             label=sankey_dict['label'],
#                             customdata = v,
#                             hovertemplate = '%{customdata}',
#                             color='#F7ED32'),
                
#                 link = dict(source = sankey_dict['data']['source'],
#                             target = sankey_dict['data']['target'],
#                             color = sankey_dict['data']['CD3E_rgba'],
#                             value = sankey_dict['data']['value']))])


dag_trace=dict(type='sankey', orientation='h',
                node = dict(pad = 10, thickness=10,
#                             line = dict(color = "black", width = 0.5),
                            label=sankey_dict['label'],
                            customdata = v,
                            hovertemplate = '%{customdata}',
                            color='#F7ED32'),
                
                link = dict(source = sankey_dict['data']['source'],
                            target = sankey_dict['data']['target'],
                            color = sankey_dict['data']['color'],
                            value = sankey_dict['data']['value']))

# font=dict(size = 10, color = 'purple')

In [None]:
# gnly_data_trace=dict(type='sankey', orientation='h', 
#                 node = dict(pad = 10, thickness=10,
# #                             line = dict(color = "black", width = 0.5),
#                             label=sankey_dict['label'],
#                             customdata = v,
#                             hovertemplate = '%{customdata}',
#                             color='#F7ED32'),
                
#                 link = dict(source = sankey_dict['data']['source'],
#                             target = sankey_dict['data']['target'],
#                             color = sankey_dict['data']['GNLY_rgba'],
#                             value = sankey_dict['data']['value']))


# dag_trace=dict(type='sankey', orientation='h',
#                 node = dict(pad = 10, thickness=10,
# #                             line = dict(color = "black", width = 0.5),
#                             label=sankey_dict['label'],
#                             customdata = v,
#                             hovertemplate = '%{customdata}',
#                             color='#F7ED32'),
                
#                 link = dict(source = sankey_dict['data']['source'],
#                             target = sankey_dict['data']['target'],
#                             color = sankey_dict['data']['GNLY_rgba'],
#                             value = sankey_dict['data']['value']))

# font=dict(size = 10, color = 'purple')

# 'GNLY','CD3E
# layout = dict(title = 'Expression of GNLY in the 10X PBMC Dataset',