## Visualize the cell-type colocalization and their GSEA and Sankey diagrams

Sankey data formatter adapted from [Data Vizardry by Viraj Deshpande](https://virajdeshpande.wordpress.com/portfolio/sankey-diagram/)

In [None]:
!date

#### import libraries

In [None]:
from pandas import read_csv, concat, read_parquet, DataFrame
from os.path import exists
from numpy import log10, log
from scipy.stats import zscore
from sklearn.preprocessing import MinMaxScaler

#### set notebook variables

In [None]:
# parameters
# gene-set library whose GSEA results are visualized; the name is embedded in
# the GSEA input file name and in every output figure file name
# (presumably an Enrichr library name -- confirm against the upstream GSEA notebook)
# values listed by the author: 'GO_Cellular_Component_2021', 'GO_Biological_Process_2021', 'KEGG_2021_Human', 'MSigDB_Hallmark_2020'
gene_set = 'MSigDB_Hallmark_2020'

In [None]:
# naming: cohort, diagnosis and day label used to build file names, plus the
# cell-type that gets its own focused diagram
cohort, dx, day = 'foundin', 'PD', 'daNA'
target_cell = 'DAn-meta'

# directories, all rooted at the working directory
wrk_dir = '/labshare/raph/datasets/foundin_qtl'
results_dir = wrk_dir + '/results'
public_dir = wrk_dir + '/public'
figures_dir = wrk_dir + '/figures'

# every GSEA input and Sankey output shares this path stem
coloc_stem = f'{figures_dir}/{cohort}.colocalization.{dx}.{gene_set}'

# in files
index_variants_file = f'{public_dir}/nalls_pd_gwas/index_variants.list'
gsea_file = f'{coloc_stem}.gsea_enrichr.csv'

# out files: one HTML document per diagram
figure_file = f'{coloc_stem}.sankey.html'
target_cell_figure_file = f'{coloc_stem}.sankey.{target_cell}.html'
cell_figure_file = f'{coloc_stem}.sankey.cellpair.html'
gsea_figure_file = f'{coloc_stem}.sankey.gseapair.html'

# variables
DEBUG = False
# cell-type result sets to load and draw
modalities = [
    'Bryois-Astro',
    'Bryois-Endo',
    'Bryois-ExN',
    'Bryois-InN',
    'Bryois-Micro',
    'Bryois-OPC',
    'Bryois-Oligo',
    'Bryois-Peri',
    'DAn-meta',
]
dpi_value = 100
# colocalization rows must exceed this H4 value to be kept
min_h4 = 0.5
# adjusted p-value cutoff applied to the GSEA terms
alpha_value = 0.05
# common column names given to every link table
link_cols = ['source', 'target', 'weight']

### load input data

#### for each modality load the colocalization results for the configured day

In [None]:
# Load the colocalization results of every modality that has a result file and
# stack them into a single frame, tagging each row with its day and modality.
# Frames are collected in a list and concatenated once, instead of re-copying
# the growing frame on every iteration.
loaded_frames = []
for modality in modalities:
    print(day, modality, end=':')
    in_file = f'{results_dir}/{cohort}_{day}_{modality}_{dx}.coloc.pp.csv'
    if exists(in_file):
        this_df = read_csv(in_file)
        print(f'loaded {this_df.shape[0]} results')
        # add day and modality
        this_df['day'] = day
        this_df['modality'] = modality
        loaded_frames.append(this_df)
    else:
        # finish the progress line so the next modality starts on its own line
        print(f'no results file at {in_file}')
if not loaded_frames:
    # without this check the failure surfaces later as an AttributeError on None
    raise FileNotFoundError(
        f'no colocalization results found under {results_dir} for '
        f'{cohort} {day} {dx} in any of {modalities}')
coloc_df = concat(loaded_frames)
print(f'\ntotal results loaded {coloc_df.shape[0]}')
if DEBUG:
    display(coloc_df.sample(5))
    display(coloc_df.day.value_counts())
    display(coloc_df.modality.value_counts())

#### subset based on the minimum H4 variable

In [None]:
# Row-level filter: keep only the feature/modality rows whose own H4 value
# exceeds min_h4. (The alternative of keeping every row of a feature that
# passes in at least one modality is intentionally not used here.)
passes_h4 = coloc_df['H4'] > min_h4
coloc_df = coloc_df[passes_h4]
print(f'results shape after filter on H4 {coloc_df.shape}')
if DEBUG:
    display(coloc_df.head())
    display(coloc_df['modality'].value_counts())

#### create the feature to cell-type colocalization links

In [None]:
# One link per retained colocalization row: feature -> cell-type, weighted by H4.
feat_cell_links = coloc_df[['feature', 'modality', 'H4']].copy()
print(f'feature to cell-type colocalization links shape {feat_cell_links.shape}')
feat_cell_links.columns = link_cols
# standardize the values: z-score, min-max scale, then add a constant offset
# so that no link ends up with zero width
if not feat_cell_links.empty:  # MinMaxScaler rejects an array with no samples
    # zscore is given a plain ndarray so its result is an ndarray on every
    # SciPy version; calling `.values` on the result only works on versions
    # that hand a pandas object back for pandas input
    scaled = MinMaxScaler().fit_transform(
        zscore(feat_cell_links['weight'].to_numpy()).reshape(-1, 1))
    # NOTE(review): the offset was written as `+0.1*10`, which operator
    # precedence evaluates to +1.0 (the term links use +0.1); the value is kept
    # as-is -- confirm that `(... + 0.1) * 10` was not what was intended
    # flatten the (n, 1) scaler output before storing it as a column
    feat_cell_links['weight'] = scaled.ravel() + 1.0
if DEBUG:
    display(feat_cell_links.head())

#### load the colocalization GSEA Enrichr results

In [None]:
# Read the GSEA table (first CSV column becomes the index), then narrow it in
# two reported steps: significant terms first, selected modalities second.
gsea_df = read_csv(gsea_file, index_col=0)
print(f'shape of GSEA results {gsea_df.shape}')
# subset to stat significant terms, judged on the 'Adjusted P-value' column
is_significant = gsea_df['Adjusted P-value'] <= alpha_value
gsea_df = gsea_df[is_significant]
print(f'shape of GSEA statistically significant results {gsea_df.shape}')
# subset to just modalities of interest
in_selected_modality = gsea_df['modality'].isin(modalities)
gsea_df = gsea_df[in_selected_modality]
print(f'shape of GSEA in selected modalities {gsea_df.shape}')
if DEBUG:
    display(gsea_df.head())

#### clean-up the GO term entity

In [None]:
# gsea_df['Gene_set'] = gsea_df.Gene_set.str.replace('GO_','')
# gsea_df['Term'] = gsea_df.Gene_set.str.replace('_2021',': ') + gsea_df.Term
# print(f'shape of GSEA post Term naming cleanup {gsea_df.shape}')
# if DEBUG:
#     display(gsea_df.head())

#### create the weight values; -log10(p-value) or log odds

In [None]:
# Derive the two candidate link weights in one step: the negative base-10 log
# of the p-value and the natural log of the odds ratio.
gsea_df = gsea_df.assign(
    log10_pvalue=-log10(gsea_df['P-value']),
    log_odds=log(gsea_df['Odds Ratio']),
)
print(f'shape of modified GSEA results {gsea_df.shape}')
if DEBUG:
    display(gsea_df.head())

#### create the term to cell-type links

In [None]:
# One link per significant enrichment: cell-type -> term, weighted by the log
# odds ratio (swap 'log_odds' for 'log10_pvalue' to weight by significance).
term_cell_links = gsea_df[['modality', 'Term', 'log_odds']].copy()
print(f'term to cell-type links shape {term_cell_links.shape}')
term_cell_links.columns = link_cols
# standardize the values: z-score, min-max scale, then add a small offset so
# that no link ends up with zero width
if not term_cell_links.empty:  # no significant terms: MinMaxScaler rejects an array with no samples
    # zscore is given a plain ndarray so its result is an ndarray on every
    # SciPy version; calling `.values` on the result only works on versions
    # that hand a pandas object back for pandas input
    scaled = MinMaxScaler().fit_transform(
        zscore(term_cell_links['weight'].to_numpy()).reshape(-1, 1))
    # flatten the (n, 1) scaler output before storing it as a column
    term_cell_links['weight'] = scaled.ravel() + 0.1
if DEBUG:
    display(term_cell_links.head())

#### add a 'No Enrichments' placeholder link for any cell-types without a GSEA enrichment

In [None]:
# Give every modality that has no enrichment link a placeholder link to a
# 'No Enrichments' node (weight 1) so the cell-type still appears downstream.
missing_modals = set(modalities) - set(term_cell_links.source)
print(missing_modals)
# walk `modalities` rather than the set so the placeholder rows come out in the
# same order on every run
lists_to_add = [[modality, 'No Enrichments', 1]
                for modality in modalities if modality in missing_modals]
misssing_df = DataFrame(data=lists_to_add, columns=link_cols)
print(f'shape of missing modalities {misssing_df.shape}')
if not misssing_df.empty:
    # only concatenate real rows: pandas has deprecated ignoring the dtypes of
    # empty frames in concat, so appending an empty object-typed frame could
    # change the dtype of the weight column
    term_cell_links = concat([term_cell_links, misssing_df])
print(f'updated term to cell-type links shape {term_cell_links.shape}')
if DEBUG:
    display(term_cell_links.head())
    display(term_cell_links.tail())

### combine the link data

In [None]:
# Stack the feature -> cell-type links on top of the cell-type -> term links;
# both tables already share the same three link columns.
all_link_tables = (feat_cell_links, term_cell_links)
links_df = concat(all_link_tables)
print(f'shape of all links to include {links_df.shape}')
if DEBUG:
    display(links_df.head())

### visualize as Sankey diagram

In [None]:
import plotly.offline as pyoff

# function from Viraj Deshpande at https://virajdeshpande.wordpress.com/portfolio/sankey-diagram/
def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    """Build a plotly figure dictionary describing a Sankey diagram.

    Each consecutive pair of columns in ``cat_cols`` contributes one layer of
    links: a row links its value in the left column to its value in the right
    column, carrying the row's ``value_cols`` entry as the link value. Links
    with the same (source, target) pair are summed. A label that occurs in
    more than one column is drawn as a single shared node.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the category columns and the value column.
    cat_cols : sequence of str
        Column names ordered from the left-most level to the right-most
        level; at least two are required. The argument is only read.
    value_cols : str
        Name of the numeric column that is summed into the link values.
    title : str
        Title stored in the figure layout.

    Returns
    -------
    dict
        ``{'data': [sankey_trace], 'layout': layout}``, suitable for
        ``plotly.offline.plot(fig, validate=False)``.

    Raises
    ------
    ValueError
        If fewer than two category columns are given, since no link can be
        formed from a single level.
    """
    if len(cat_cols) < 2:
        raise ValueError('cat_cols must name at least two columns to form source-target links')

    # node labels: level by level, in order of first appearance, without
    # duplicates, so the node order is the same on every run
    labelList = list(dict.fromkeys(
        label for catCol in cat_cols for label in df[catCol].values
    ))
    # label -> node position, replacing a linear list search per row
    labelIndex = {label: position for position, label in enumerate(labelList)}

    # transform df into source-target pairs, one frame per adjacent column pair
    pairFrames = []
    for sourceCol, targetCol in zip(cat_cols[:-1], cat_cols[1:]):
        pairDf = df[[sourceCol, targetCol, value_cols]].copy()
        pairDf.columns = ['source','target','count']
        pairFrames.append(pairDf)
    # summing once over all layers equals re-summing after every layer
    sourceTargetDf = (concat(pairFrames)
                      .groupby(['source','target'])
                      .agg({'count':'sum'})
                      .reset_index())

    # add index for source-target pair
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].map(labelIndex)
    sourceTargetDf['targetID'] = sourceTargetDf['target'].map(labelIndex)

    # creating the sankey diagram
    data = dict(
        type='sankey',
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(
            color = "purple",
            width = 0.5
          ),
          label = labelList,
          color = 'purple'
        ),
        link = dict(
          source = sourceTargetDf['sourceID'],
          target = sourceTargetDf['targetID'],
          value = sourceTargetDf['count']
        )
      )

    layout =  dict(
        title = title,
        font = dict(
          size = 10
        )
    )

    fig = dict(data=[data], layout=layout)
    return fig

In [None]:
# Full network (features -> cell-types -> terms), written to a standalone HTML file.
full_title = f'PD sporadic risk: Cell-types and {gene_set} terms'
fig = genSankey(links_df, cat_cols=['source', 'target'],
                value_cols='weight', title=full_title)
pyoff.plot(fig, validate=False, filename=figure_file)

In [None]:
# Keep only the links that touch the target cell-type on either end, then draw them.
touches_target = links_df['source'].eq(target_cell) | links_df['target'].eq(target_cell)
target_cell_links = links_df[touches_target]
target_title = f'PD sporadic risk: {target_cell} and {gene_set} terms'
fig = genSankey(target_cell_links, cat_cols=['source', 'target'],
                value_cols='weight', title=target_title)
pyoff.plot(fig, validate=False, filename=target_cell_figure_file)

In [None]:
# Feature -> cell-type colocalization links only.
cell_title = 'PD sporadic risk: Features and Cell-types'
fig = genSankey(feat_cell_links, cat_cols=['source', 'target'],
                value_cols='weight', title=cell_title)
pyoff.plot(fig, validate=False, filename=cell_figure_file)

In [None]:
# Cell-type -> enriched term links only.
gsea_title = f'PD sporadic risk: Cell-types and {gene_set} terms'
fig = genSankey(term_cell_links, cat_cols=['source', 'target'],
                value_cols='weight', title=gsea_title)
pyoff.plot(fig, validate=False, filename=gsea_figure_file)

In [None]:
!date