In [1]:
import sys
sys.path.append('../../../')

import mudata

from src.evaluation import compute_categorical_association
from src.evaluation import compute_explained_variance_ratio
from src.evaluation import compute_geneset_enrichment
from src.evaluation import compute_trait_enrichment
from src.evaluation import compute_motif_enrichment

from plotly.subplots import make_subplots

import plotly.graph_objects as go
import plotly.express as px

import ipywidgets as w
from IPython.display import display

import numpy as np
import pandas as pd
from scipy import stats

from tqdm.auto import tqdm

In [2]:
# Import test data
mdata = mudata.read('../../../../../../data/TeloHAEC_Perturb-seq_2kG/2kG.library_K60_kangh.h5mu')

# Subsample: keep at most the first 200 cells of every (batch, sample) pair,
# visiting batches in the outer loop and samples in the inner loop.
cnmf_obs = mdata['cNMF'].obs

sel_idx = []
for batch in cnmf_obs['batch'].unique():
    in_batch = cnmf_obs['batch'] == batch
    for samp in cnmf_obs['sample'].unique():
        in_sample = cnmf_obs['sample'] == samp
        mdata_obs_ = cnmf_obs.loc[in_batch & in_sample]
        sel_idx += mdata_obs_.index[:200].tolist()

# Replace the full object with a copy restricted to the selected cells
mdata = mdata[sel_idx].copy()
mdata



In [3]:
# Explained variance ratio plot
explained_variance = compute_explained_variance_ratio(
    mdata, prog_key='cNMF', data_key='rna', inplace=False
)

# Since original data wasn't available we use this as a stand in:
# PCA explained-variance ratios of the program matrix, scaled by 0.76.
from sklearn.decomposition import PCA

pca = PCA(n_components=60)
pca_ratios = pca.fit(mdata['cNMF'].X).explained_variance_ratio_

explained_variance['VarianceExplained'] = pca_ratios*0.76
explained_variance = explained_variance.reset_index()
# Program identifiers are the former index values prefixed with 'K60_'
explained_variance['ProgramID'] = explained_variance['index'].map('K60_{}'.format)
explained_variance = explained_variance.drop(columns=['index', 'explained_variance_ratio_X'])

# Categorical association plot
association_df, posthoc_df = compute_categorical_association(
    mdata,
    prog_key='cNMF',
    pseudobulk_key='sample',
    categorical_key='batch',
    inplace=False,
    n_jobs=-1,
)

# # Geneset enrichment 
# gsea_df = compute_geneset_enrichment(mdata, prog_key='cNMF', library='Reactome_2013', inplace=False, n_jobs=-1)

# # Trait enrichment 
# gwas_df = compute_trait_enrichment(mdata, '../../smk/resources/OpenTargets_L2G_Filtered.csv.gz', prog_key='cNMF', inplace=False)

# # Motif enrichment
# motif_match_df, motif_count_df, motif_enrichment_df = \
# compute_motif_enrichment(mdata, prog_key='cNMF', data_key='rna', motif_file='../tests/test_data/motifs.meme',
#                          seq_file='../../../../../data/hg38/hg38.fa', coords_file='../tests/test_data/p2g_links.txt',
#                          n_jobs=1, inplace=False)

Computing explained variance:   0%|          | 0/60 [00:00<?, ?programs/s]

INFO:root:Perform testing by averaging over sample


Testing batch association:   0%|          | 0/60 [00:00<?, ?programs/s]

  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupby([pseudobulk_key, categorical_key]).mean().dropna().reset_index()
  prog_data = prog_data.groupb

Identifying differential batch:   0%|          | 0/60 [00:00<?, ?programs/s]

In [4]:
# Save trait enrichment data
# gwas_df.to_csv('dashapp/example_data/OpenTargets_L2G_cNMF_program_Enrichment_Results.csv')

In [5]:
# Load example data
# The workbook is opened once and closed deterministically by the context
# manager; previously the ExcelFile handle was never closed.
with pd.ExcelFile('example_data/cNMF_evaluation_output.xlsx') as evaluation_output:

    explained_variance = pd.read_excel(evaluation_output, sheet_name='explained_variance')
    enrichment_motif = pd.read_excel(evaluation_output, sheet_name='enrichment_motif')
    enrichment_gsea = pd.read_excel(evaluation_output, sheet_name='enrichment_gsea')
    enrichment_trait = pd.read_excel(evaluation_output, sheet_name='enrichment_trait')

    program_statistics = pd.read_excel(evaluation_output, sheet_name='program_statistics')
    # NOTE: the variable name is misspelled ("statistitcs") but is kept as-is
    # because other cells of this notebook read it under this name.
    regulator_statistitcs = pd.read_excel(evaluation_output, sheet_name='regulator_statistics')


In [6]:
# Assemble dashboard - page 1

# Plot unique terms per program
def count(categorical_var, count_var, dataframe):
    """Count the rows of ``dataframe`` that fall into each category.

    Parameters
    ----------
    categorical_var : str
        Column whose values define the groups (e.g. a program identifier).
    count_var : str
        Column that is counted; it is also the name of the count column in
        the returned frame.
    dataframe : pandas.DataFrame
        Table containing both columns.

    Returns
    -------
    pandas.DataFrame
        Two columns, ``categorical_var`` and ``count_var``, one row per
        category, ordered from the largest to the smallest count. The count
        column keeps its integer dtype.

    Notes
    -----
    Occurrences of every (category, value) pair are tallied and then summed
    per category, so repeated values are counted once per row. Rows with a
    missing value in either column are skipped (``value_counts`` default).
    """
    pair_counts = dataframe.value_counts([categorical_var, count_var])
    per_category = pair_counts.groupby(level=categorical_var).sum()
    per_category = per_category.sort_values(ascending=False)

    # Build the frame from the Series itself rather than from `.values`:
    # going through a mixed label/count array would turn the counts into
    # Python objects and lose the integer dtype.
    counts_df = per_category.rename(count_var).reset_index()

    return counts_df

def count_unique(categorical_var, count_var, dataframe, 
                 cummul=False, unique=False):
    """Count, per category, the values not already seen in an earlier category.

    Categories are visited in the order returned by ``count`` (largest row
    count first). For each one, only the distinct ``count_var`` values that
    did not occur in any previously visited category are counted.

    Parameters
    ----------
    categorical_var : str
        Column whose values define the groups.
    count_var : str
        Column holding the terms to count; also the name of the count column
        in the returned frame.
    dataframe : pandas.DataFrame
        Table containing both columns.
    cummul : bool, default False
        If True, return the running total of newly seen values instead of
        the per-category number.
    unique : bool, default False
        NOTE: despite its name, ``unique=True`` returns the plain per-category
        row counts from ``count``; the newly-seen counts are returned only
        when it is False. ``cummul`` has no effect when ``unique`` is True.

    Returns
    -------
    pandas.DataFrame
        Two columns, ``categorical_var`` and ``count_var``.
    """
    counts_df = count(categorical_var, count_var, dataframe)

    # The raw counts are returned untouched, so skip the novelty pass entirely.
    if unique:
        return counts_df

    # A set gives constant-time membership tests; the previous list lookup
    # made this loop quadratic in the total number of terms.
    seen_terms = set()
    records = []
    for prog in counts_df[categorical_var].unique():

        prog_terms = dataframe.loc[dataframe[categorical_var]==prog, count_var].unique()
        novel_terms = [term for term in prog_terms if term not in seen_terms]

        seen_terms.update(novel_terms)
        records.append([prog, len(novel_terms)])

    new_df = pd.DataFrame(records, columns=[categorical_var, count_var])

    if cummul:
        new_df[count_var] = new_df[count_var].cumsum()

    return new_df

# Page-1 canvas: a 3x2 grid holding five scatter panels and, in (3, 2),
# a correlation heatmap of the plotted measures.
fig = make_subplots(
    rows=3, cols=2,
    specs=[[{}, {}] for _ in range(3)],
    print_grid=True,
    subplot_titles=(' component wise R2 scores',
                    'Number of enriched GO terms',
                    'Number of enriched promoter Motifs',
                    'Number of enriched enhancer Motifs',
                    'Number of enriched GWAS traits'),
    vertical_spacing=0.05, horizontal_spacing=0.1)

# Switches forwarded to count_unique and reused below for sorting / filling
cummul = False
unique = False

# Rows passing the 0.05 threshold for each kind of enrichment
gsea_hits = enrichment_gsea.loc[enrichment_gsea['qvalue'] <= 0.05]
gwas_hits = enrichment_trait.loc[enrichment_trait['Adjusted P-value'] <= 0.05]
motif_significant = enrichment_motif['FDR'] <= 0.05
promoter_hits = enrichment_motif.loc[motif_significant & (enrichment_motif['EPType'] == 'Promoter')]
enhancer_hits = enrichment_motif.loc[motif_significant & (enrichment_motif['EPType'] == 'Enhancer')]

# Per-program term counts: GSEA terms, GWAS terms, promoter motifs, enhancer motifs
gsea_unique_df = count_unique('ProgramID', 'ID', gsea_hits,
                              cummul=cummul, unique=unique)
gwas_unique_df = count_unique('ProgramID', 'Term', gwas_hits,
                              cummul=cummul, unique=unique)
motif_unique_promoter_df = count_unique('ProgramID', 'TFMotif', promoter_hits,
                                        cummul=cummul, unique=unique)
motif_unique_enhancer_df = count_unique('ProgramID', 'TFMotif', enhancer_hits,
                                        cummul=cummul, unique=unique)

# Generate example data for dashapp: one row per program, one column per measure
a = explained_variance.copy()
r2_scores = a['VarianceExplained'].cumsum() if cummul else a['VarianceExplained']
a[' Component wise R2 scores'] = r2_scores
a = a.drop(columns=['VarianceExplained'])

# Attach each count table under its display label. The order of this list
# fixes the column order of `a`, and therefore the panel order of the figure.
for summary_df, source_col, label in [
        (gsea_unique_df, 'ID', 'Number of enriched GO terms'),
        (motif_unique_promoter_df, 'TFMotif', 'Number of enriched promoter Motifs'),
        (motif_unique_enhancer_df, 'TFMotif', 'Number of enriched enhancer Motifs'),
        (gwas_unique_df, 'Term', 'Number of enriched GWAS traits'),
]:
    summary_df[label] = summary_df[source_col]
    a = a.merge(summary_df[['ProgramID', label]], on='ProgramID', how='left')

# Programs absent from a count table get 0 (or the column maximum when cumulative)
for label in ['Number of enriched GWAS traits',
              'Number of enriched GO terms',
              'Number of enriched promoter Motifs',
              'Number of enriched enhancer Motifs']:
    fill_value = a[label].max() if cummul else 0
    a[label] = a[label].fillna(fill_value)

# One scatter per measure, programs ordered by that measure; panels are
# filled row by row, two per row.
plots = {}
measure_cols = [col for col in a.columns if 'ProgramID' not in col]
for i, col in enumerate(measure_cols):

    ordered = a.sort_values(col, ascending=cummul)
    plots[col] = px.scatter(x=ordered['ProgramID'], y=ordered[col])
    plots[col].update_layout(xaxis_title='Components', yaxis_title=col)

    row_offset, col_offset = divmod(i, 2)
    row_num = row_offset + 1
    col_num = col_offset + 1
    fig.add_trace(plots[col]['data'][0], row=row_num, col=col_num)
    fig.update_xaxes(showticklabels=False, row=row_num, col=col_num)
    fig.update_yaxes(ticksuffix="  ", row=row_num, col=col_num)

# Update plots
fig.update_traces(hovertemplate="Program Name: %{x} <br> Value: %{y}", marker_color='black')

# Add corrmap: Spearman correlation of every column after the first one
corr_map = px.imshow(a.iloc[:, 1:].corr(method='spearman'), text_auto=True)
fig.add_trace(corr_map['data'][0], row=3, col=2)
fig.update_yaxes(showticklabels=False, row=3, col=2)

fig.update_layout(height=1050, width=1000,
                  coloraxis_showscale=False,
                  plot_bgcolor='whitesmoke',
                  title_text="GEP Dashboard - v0.1 - Goodness of fit measures")

fig.write_html('example_output/0_goodness_of_fit.html')
# fig.show()

This is the format of your plot grid:
[ (1,1) x,y   ]  [ (1,2) x2,y2 ]
[ (2,1) x3,y3 ]  [ (2,2) x4,y4 ]
[ (3,1) x5,y5 ]  [ (3,2) x6,y6 ]



In [34]:
# Page 2 - program assessment
#
# Builds a 4x2 subplot figure with a dropdown that switches between programs,
# then writes it to an HTML file. Every panel receives one trace per program
# (only the first is visible initially); the dropdown toggles trace visibility.
# This cell reads `association_df` and `mdata` from earlier cells and the
# tables loaded from the example workbook, and it adds columns to
# `association_df` in place.

# Batch association
# FDR-adjust the Kruskal-Wallis p-values, then derive the volcano-plot axes.
# NOTE(review): these two transforms use the natural log (np.log) whereas the
# other volcano panels below use np.log10 - confirm this is intended.
association_df['batch_sample_kruskall_wallis_fdr'] = stats.false_discovery_control(association_df.batch_sample_kruskall_wallis_pval)
association_df['batch_sample_kruskall_wallis_neg_log_fdr'] = association_df['batch_sample_kruskall_wallis_fdr'].apply(lambda x: -np.log(x))

association_df['batch_sample_kruskall_wallis_log_stat'] = association_df['batch_sample_kruskall_wallis_stat'].apply(lambda x: np.log(x))

# Program scores across batch
# One row per cell, one column per program, plus the sample / batch annotations.
prog_df = pd.DataFrame(mdata['cNMF'].X, index=mdata['cNMF'].obs.index)
prog_df['sample'] = mdata['cNMF'].obs['sample']
prog_df['batch'] = mdata['cNMF'].obs['batch']

# Loadings
# After the transpose, rows are the entries of uns['var_names'] and columns
# are the cNMF var_names (the programs iterated over below).
loadings = pd.DataFrame(mdata['cNMF'].varm['loadings'], 
                        index=mdata['cNMF'].var_names,
                        columns=mdata['cNMF'].uns['var_names']).T    
# Scale every column by its largest absolute value
loadings = loadings/abs(loadings).max(axis=0)

# Regulators
regulatory = regulator_statistitcs.loc[:, ['Perturbation', 'ProgramsRegulated', 'log2FC', 'AcrossProgramsFDR']]

# Make figure
# Row 3 is a single panel spanning both columns; row 4 is created but no
# trace is added to it in this cell.
fig = make_subplots(rows=4, cols=2,
                    specs=[
                           [{}, {}],
                           [{}, {}],
                           [{'colspan':2}, None],
                           [{}, {}],
                          ],
                    print_grid=True,
                    subplot_titles=('Enrichment w.r.t. batch', 
                                    'MAGMA regression',
                                    'Program distribution across batch',  
                                    'Program-gene KD effect',    
                                    'Program-gene loadings (normalised)'
                                    ),
                    vertical_spacing = 0.1, 
                    horizontal_spacing = 0.1)

# Batch-association volcano (row 1, col 1). The x/y data do not depend on the
# loop variable, so the same trace is added once per program; only the first
# copy starts visible.
for k, r in enumerate([col for col in prog_df.columns if col not in ['sample', 'batch']]):
       # Plot volcano
       volcano = go.Scatter(x=association_df.batch_sample_kruskall_wallis_log_stat,
                            y=association_df.batch_sample_kruskall_wallis_neg_log_fdr,
                            customdata=association_df.index.values,
                            hovertemplate=" Program Name: %{customdata}",
                            showlegend=False, mode='markers', 
                            visible=True if k == 0 else False,
                            marker_color='black',
                            )
       fig.add_trace(volcano, row=1, col=1)
       fig.update_xaxes(showticklabels=False, row=1, col=1)
       fig.update_yaxes(title='Neg. log adjusted pval', ticksuffix = "  ", row=1, col=1)

# Plot box
# One box trace per program column: program score grouped by batch (row 2, col 1)
for k, r in enumerate([col for col in prog_df.columns if col not in ['sample', 'batch']]):

    fig.add_trace(
                  go.Box(x=prog_df.batch, 
                         y=prog_df[r],
                         visible=True if k == 0 else False,
                         customdata=prog_df.index.values,
                         hovertemplate="Cell barcode: %{customdata}",
                         marker_color='black',
                        ), 
                   row=2, col=1)
# fig.update_xaxes(showticklabels=False, row=2, col=1)
fig.update_yaxes(ticksuffix = "  ", row=2, col=1)

# Plot loadings
# One bar trace per program showing its 100 highest normalised loadings (row 3)
for k, r in enumerate(loadings.columns):

    dfp = loadings.sort_values(r, ascending=False)[:100]
    # This column is not read anywhere else in this cell
    dfp['regulatory'] = None 
    fig.add_trace(
                  go.Bar(x=dfp.index, 
                         y=dfp[r],
                         name='', 
                         orientation='v',
                         hovertemplate="Gene Name: %{x} <br> Loading: %{y}",
                         visible=True if k == 0 else False,
                         marker_color='black',
                        ), 
                   row=3, col=1)
fig.update_xaxes(showticklabels=False, row=3, col=1)
fig.update_yaxes(tickvals=np.arange(0,1.25,0.25), ticksuffix = "  ", row=3, col=1)

# Plot regulators
# One volcano per program (row 2, col 2), restricted to the perturbations
# whose 'ProgramsRegulated' value equals 'K60_<program>'.
for k, r in enumerate(loadings.columns):

       regulatory_ = regulatory.loc[regulatory.ProgramsRegulated=='K60_{}'.format(r)]

       # Plot volcano
       volcano = go.Scatter(x=regulatory_.log2FC,
                            y=regulatory_.AcrossProgramsFDR.apply(lambda x: -np.log10(x)),
                            customdata=regulatory_.Perturbation.values,
                            hovertemplate=" Perturbed gene: %{customdata}",
                            showlegend=False, mode='markers', 
                            visible=True if k == 0 else False,
                            marker_color='black',
                            )
       fig.add_trace(volcano, row=2, col=2)
       fig.update_xaxes(showticklabels=False, row=2, col=2)
       fig.update_yaxes(title='Neg. log adjusted pval', ticksuffix = "  ", row=2, col=2)

# Plot MAGMA regression (row 1, col 2)
# As with the batch volcano, the data do not depend on the loop variable, so
# identical traces are added once per program.
# NOTE(review): the y axis is titled "adjusted pval" but plots
# -log10(CAD_MagmaPvalue); confirm whether that column is adjusted.
for k, r in enumerate(loadings.columns):

       # Plot volcano
       volcano = go.Scatter(x=program_statistics.CAD_MagmaBeta,
                            y=program_statistics.CAD_MagmaPvalue.apply(lambda x: -np.log10(x)),
                            customdata=program_statistics.ProgramID.values,
                            hovertemplate=" Program: %{customdata}",
                            showlegend=False, mode='markers', 
                            visible=True if k == 0 else False,
                            marker_color='black',
                            )
       fig.add_trace(volcano, row=1, col=2)
       fig.update_xaxes(showticklabels=False, row=1, col=2)
       fig.update_yaxes(title='Neg. log adjusted pval', ticksuffix = "  ", row=1, col=2)


# Define buttons for dropdown
# One button per program; each carries a boolean mask with a single True.
# NOTE(review): the mask has len(col_opts) entries while the figure holds five
# traces per program (one per loop above). This presumably relies on plotly.js
# cycling a short array across all traces, which only lines up if every loop
# adds its traces in the same program order - confirm before reordering loops.
# NOTE(review): 'title' sits in the first args dict, which the 'update' method
# presumably applies to traces rather than to the layout - verify that the
# figure title is meant to change with the selection.
col_opts = list(loadings.columns)
buttons_opts = []
for i, opt in enumerate(col_opts):
    args = [False] * len(col_opts)
    args[i] = True
    buttons_opts.append(
        dict(
             method='update',
             label=opt,
             args=[{
                    'visible': args, #this is the key line!
                    'title': opt,
                    'showlegend': False
                   }]
            ))

fig.update_layout(height=1400, width=1000, 
                  plot_bgcolor='whitesmoke',

                  title_text="GEP Dashboard - v0.1 - Investigate GEPs",
                  updatemenus = [go.layout.Updatemenu(
                                 active=0,
                                 buttons=buttons_opts,
                                 x=-0.1,
                                 xanchor='left',
                                 y=1,
                                 yanchor='bottom')]
                 )
# Export the interactive figure as a standalone HTML page
fig.write_html('example_output/1_covarite_association.html')
#fig.show()

This is the format of your plot grid:
[ (1,1) x,y   ]  [ (1,2) x2,y2 ]
[ (2,1) x3,y3 ]  [ (2,2) x4,y4 ]
[ (3,1) x5,y5           -      ]
[ (4,1) x6,y6 ]  [ (4,2) x7,y7 ]



In [16]:
# Inspect which columns the program_statistics table provides
program_statistics.columns

Index(['ProgramID', 'CuratedLabel', 'CategoryEndothelialCellSpecificityZscore',
       'CategoryBatchCorrelation', 'ProgramCategoryLabel', 'ProgramGenesTop10',
       'ProgramGenesMotifsPromoter', 'ProgramGenesMotifsEnhancer',
       'ProgramGeneSetsTop10', 'ProgramGenesMAPK', 'ProgramGenesECM',
       'ProgramGenesECCandidate', 'ProgramGenesKnownCAD',
       'RegulatorsECCandidate', 'RegulatorsKnownCAD', 'RegulatorsMAPK',
       'RegulatorsTF', 'CAD_MagmaBeta', 'CAD_MagmaPvalue',
       'CAD_RegulatorECCandidateEnrichment',
       'CAD_ProgramGeneECCandidateEnrichment', 'CAD_ECCandidateFDR',
       'LDSC_Enrichment', 'LDSC_PValue'],
      dtype='object')

In [31]:
# Show the regulator rows whose 'ProgramsRegulated' value is 'K60_19'
regulator_statistitcs.loc[regulator_statistitcs.ProgramsRegulated=='K60_19']

Unnamed: 0,PerturbationOriginalName,Perturbation,ProgramsRegulated,log2FC,p.value,AcrossProgramsFDR,ExperimentWideFDR,IsMultiTarget,Gene.target.full.name,Gene2,TeloHAEC_ctrl_TPM,KDEfficiencyLog2FC,KDEfficiencyPValue,KDEfficiencyFDR
6,ANKRD22,ANKRD22,K60_19,0.1225,0.000171,0.010257,0.041135,False,,,0.0,-6.650000000000001e-17,1.0,1.0
86,FYCO1,FYCO1,K60_19,-0.24596,0.000118,0.007097,0.032332,False,,,16.98391,-0.236235,0.214929,0.317297
141,METAP2,METAP2,K60_19,-0.2443,8.7e-05,0.005203,0.026196,False,,,122.9232,-0.931018,1.61e-10,3.33e-09
193,PMVK,PMVK,K60_19,0.36288,1e-06,3.4e-05,0.000455,False,,,47.97226,-1.032499,3.23e-06,2.98e-05
222,RNF146,RNF146,K60_19,0.36188,0.000184,0.010216,0.042579,False,,,27.36403,-0.53029,0.013093,0.036337


In [None]:
# Program statistics

# Magma CAD_MagmaBeta	CAD_MagmaPvalue

# LDSC LDSC_Enrichment	LDSC_PValue

# Similar to loadings - genes whose KD affects the program

# Trait heritability


In [None]:
# Regulator statistics

# % KD KDEfficiencyLog2FC	KDEfficiencyPValue	KDEfficiencyFDR

# Log2FC on which program  Perturbation	ProgramsRegulated	log2FC	p.value	AcrossProgramsFDR	ExperimentWideFDR


In [None]:
# Plot trait enrichment interpretation
def plot_interactive_phewas(data, x_column='trait_reported',
                            y_column='-log10(p-value)',
                            color_column='trait_category',
                            filter_column='program_name',
                            significance_threshold=0.05,
                            annotation_cols=["program_name", "trait_reported",
                                             "trait_category", "P-value",
                                             "Genes", "study_id", "pmid"],
                           query_string="trait_category != 'measurement'",
                           title="Cell Program x OpenTargets GWAS L2G Enrichment"):
    """Display an interactive scatter plot with a dropdown filter.

    A dropdown lists ``'All'`` plus every distinct value of ``filter_column``;
    choosing an entry redraws the scatter plot for the matching rows.

    Parameters
    ----------
    data : pandas.DataFrame
        Table containing the columns named by the other arguments.
    x_column, y_column : str
        Columns plotted on the x and y axes; also used as the axis titles.
    color_column : str
        Column used to colour the points.
    filter_column : str
        Column whose distinct values populate the dropdown.
    significance_threshold : float
        A dashed horizontal line is drawn at ``-log10`` of this value.
    annotation_cols : list of str
        Columns passed to the plot as hover data.
    query_string : str or None
        If truthy, ``data`` is filtered with ``DataFrame.query`` before
        plotting. The dropdown options are collected before this filter is
        applied, so an option may lead to an empty plot.
    title : str
        Figure title.

    Returns
    -------
    None
        The widgets are shown with ``IPython.display.display``.
    """

    # Get unique values for the filtering column
    filter_values = ['All'] + list(data[filter_column].unique())

    # Initialize output widget to display the plot.
    # Widgets are reached through the `w` alias of ipywidgets imported at the
    # top of the notebook; the bare names Output / Dropdown / VBox were never
    # imported and raised NameError.
    output = w.Output()

    if query_string:
        data = data.query(query_string)

    # Function to update plot based on dropdown selection
    def update_plot(selected_value):
        # Filter data based on selected value
        if selected_value == "All":
            filtered_data = data.copy()  # No selection, show all data
        else:
            filtered_data = data[data[filter_column] == selected_value]

        # Create the plot
        fig = px.scatter(filtered_data, x=x_column, y=y_column, color=color_column,
                         title=title,
                         hover_data=annotation_cols)

        # Customize layout
        fig.update_layout(
            xaxis_title=x_column,
            yaxis_title=y_column,
            yaxis=dict(tickformat=".1f"),
            width=1000,  # Adjust width as needed
            height=800,  # Adjust height as needed,
            xaxis_tickfont=dict(size=4)
        )

        # Add horizontal dashed line for significance threshold
        fig.add_hline(y=-np.log10(significance_threshold), line_dash="dash",
                      annotation_text=f'Significance Threshold ({significance_threshold})', annotation_position="top right")

        # Clear previous plot and display the new one
        with output:
            output.clear_output(wait=True)
            fig.show()

    # Create dropdown widget
    dropdown = w.Dropdown(options=filter_values, description=f"{filter_column}:")

    # Define function to handle dropdown value change
    def on_change(change):
        if change['type'] == 'change' and change['name'] == 'value':
            update_plot(change['new'])

    # Link dropdown change to function (only value changes are of interest)
    dropdown.observe(on_change, names='value')

    # Display dropdown widget and initial plot
    display(w.VBox([dropdown, output]))

    # Draw the plot for the initially selected entry; without this call the
    # output area stays empty until the user picks a different option.
    update_plot(dropdown.value)