In [None]:
import sys
sys.path.append('../../../')

import mudata

from src.evaluation import compute_categorical_association
from src.evaluation import compute_explained_variance_ratio
from src.evaluation import compute_geneset_enrichment
from src.evaluation import compute_trait_enrichment
from src.evaluation import compute_motif_enrichment

from plotly.subplots import make_subplots

import plotly.graph_objects as go
import plotly.express as px

import ipywidgets as w
from IPython.display import display

import numpy as np
import pandas as pd
from scipy import stats

from tqdm.auto import tqdm

In [None]:
# Import test data, then keep a small, balanced subset for the demo: the first
# 200 cells (in obs order) of every batch x sample combination.
mdata = mudata.read('../../../../../../data/TeloHAEC_Perturb-seq_2kG/2kG.library_K60_kangh.h5mu')

cnmf_obs = mdata['cNMF'].obs
sel_idx = []
for batch_label in cnmf_obs['batch'].unique():
    batch_mask = cnmf_obs['batch'] == batch_label
    for sample_label in cnmf_obs['sample'].unique():
        # Combinations that do not occur simply contribute no cells.
        group_obs = cnmf_obs.loc[batch_mask & (cnmf_obs['sample'] == sample_label)]
        sel_idx += group_obs.index[:200].tolist()

# Subset to the selected cells; .copy() makes the subset an independent object.
mdata = mdata[sel_idx].copy()
mdata

In [None]:
# Evaluation metrics on the subsampled data.
# Only the categorical (batch) association is computed live here; the other
# metrics are read from the pre-computed example workbook in a later cell.
# The calls that would compute them are kept below, commented out, for reference.

# Explained variance ratio plot
# explained_variance = compute_explained_variance_ratio(mdata, prog_key='cNMF', data_key='rna', inplace=False)

# # Since original data wasn't available we use this as a stand in
# from sklearn.decomposition import PCA
# pca = PCA(n_components=60)
# explained_variance['explained_variance_ratio_X'] = pca.fit(mdata['cNMF'].X).explained_variance_ratio_*0.93

# Categorical association: program scores vs. batch, pseudobulked by sample.
# Returns the association table and the post-hoc table (inplace=False keeps
# mdata untouched).
categorical_association_kwargs = dict(
    prog_key='cNMF',
    pseudobulk_key='sample',
    categorical_key='batch',
    inplace=False,
    n_jobs=-1,
)
association_df, posthoc_df = compute_categorical_association(mdata, **categorical_association_kwargs)

# # Geneset enrichment
# gsea_df = compute_geneset_enrichment(mdata, prog_key='cNMF', library='Reactome_2013', inplace=False, n_jobs=-1)

# # Trait enrichment
# gwas_df = compute_trait_enrichment(mdata, '../../smk/resources/OpenTargets_L2G_Filtered.csv.gz', prog_key='cNMF', inplace=False)

# # Motif enrichment
# motif_match_df, motif_count_df, motif_enrichment_df = \
# compute_motif_enrichment(mdata, prog_key='cNMF', data_key='rna', motif_file='../tests/test_data/motifs.meme',
#                          seq_file='../../../../../data/hg38/hg38.fa', coords_file='../tests/test_data/p2g_links.txt',
#                          n_jobs=1, inplace=False)

In [4]:
# Save trait enrichment data
# gwas_df.to_csv('dashapp/example_data/OpenTargets_L2G_cNMF_program_Enrichment_Results.csv')

In [None]:
# Load example data: pre-computed evaluation results, one sheet per metric.
# The context manager closes the workbook file once every sheet has been read.
with pd.ExcelFile('example_data/cNMF_evaluation_output.xlsx') as evaluation_output:
    # All sheets come from the open workbook. (The explained-variance sheet was
    # previously read from `explained_variance`, which is undefined at this
    # point because its computation above is commented out.)
    explained_variance = pd.read_excel(evaluation_output, sheet_name='explained_variance')
    enrichment_motif = pd.read_excel(evaluation_output, sheet_name='enrichment_motif')
    enrichment_gsea = pd.read_excel(evaluation_output, sheet_name='enrichment_gsea')
    enrichment_trait = pd.read_excel(evaluation_output, sheet_name='enrichment_trait')


In [None]:
# Assemble dashboard - page 1

# Plot unique terms per program
def count_unique(categorical_var, count_var, dataframe):
    """Count the distinct values of `count_var` within each `categorical_var` group.

    Rows where either column is missing are ignored, and a repeated
    (group, value) pair is counted only once.

    Parameters
    ----------
    categorical_var : str
        Column to group by (e.g. 'ProgramID').
    count_var : str
        Column whose distinct values are counted (e.g. a term or motif ID).
    dataframe : pandas.DataFrame
        Input table containing both columns.

    Returns
    -------
    pandas.DataFrame
        Two columns, ``[categorical_var, count_var]``. ``count_var`` holds the
        number of distinct values per group, and rows are sorted in descending
        order. Groups with no non-missing values are omitted.
    """
    # De-duplicate (group, value) pairs so each term counts once per group.
    # The previous version summed `value_counts`, which counted every row
    # (including repeats of the same term) rather than distinct terms.
    pairs = dataframe[[categorical_var, count_var]].dropna().drop_duplicates()
    counts = (pairs.groupby(categorical_var, observed=True)
              .size()
              .sort_values(ascending=False, kind='stable'))
    # Keep the original output schema: the count column reuses `count_var`'s name.
    return counts.rename(count_var).reset_index()

# Panel order for the goodness-of-fit figure. A single list drives both the
# subplot titles and the plotted columns, so each title sits over its own data
# (previously the GWAS and enhancer-motif panels were labelled the wrong way round).
metric_columns = ['Component wise R2 scores',
                  'Number of enriched GO terms',
                  'Number of enriched promoter Motifs',
                  'Number of enriched GWAS traits',
                  'Number of enriched enhancer Motifs']

# Adjusted p-value / q-value / FDR at or below which a term counts as enriched.
ENRICHMENT_ALPHA = 0.05

# 3x2 grid with the last cell empty; titles fill non-empty cells row-major.
fig = make_subplots(
    rows=3, cols=2,
    specs=[
           [{}, {}],
           [{}, {}],
           [{}, None]
          ],
    print_grid=True,
    subplot_titles=metric_columns,
    vertical_spacing = 0.05, horizontal_spacing = 0.1)

# Explained variance ratios, highest first. `sort_values` takes a column
# label; passing the Series itself (as before) raises.
explained_variance = explained_variance.sort_values('VarianceExplained', ascending=False)

# Number of distinct significant terms per program, per enrichment type.
gsea_unique_df = count_unique('ProgramID', 'ID', enrichment_gsea.loc[enrichment_gsea['qvalue']<=ENRICHMENT_ALPHA])
gwas_unique_df = count_unique('ProgramID', 'Term', enrichment_trait.loc[enrichment_trait['Adjusted P-value']<=ENRICHMENT_ALPHA])
significant_motifs = enrichment_motif.loc[enrichment_motif['FDR']<=ENRICHMENT_ALPHA]
motif_unique_promoter_df = count_unique('ProgramID', 'TFMotif', significant_motifs.loc[significant_motifs['EPType']=='Promoter'])
motif_unique_enhancer_df = count_unique('ProgramID', 'TFMotif', significant_motifs.loc[significant_motifs['EPType']=='Enhancer'])

# Generate example data for dashapp: per-program table `a` with the R2 score
# and one count column per enrichment type. `rename` returns a new frame, so
# `explained_variance` is not mutated (the original referenced an undefined
# `explained_ratios` and then modified it in place).
a = explained_variance.rename(columns={'VarianceExplained': 'Component wise R2 scores'})

counts_by_title = {
    'Number of enriched GO terms': (gsea_unique_df, 'ID'),
    'Number of enriched promoter Motifs': (motif_unique_promoter_df, 'TFMotif'),
    'Number of enriched enhancer Motifs': (motif_unique_enhancer_df, 'TFMotif'),
    'Number of enriched GWAS traits': (gwas_unique_df, 'Term'),
}
for title, (term_counts, count_col) in counts_by_title.items():
    term_counts = term_counts.loc[:, ['ProgramID', count_col]].rename(columns={count_col: title})
    a = a.merge(term_counts, on='ProgramID', how='left')
    # Programs with no significant hits are missing from `term_counts` -> 0.
    a[title] = a[title].fillna(0).astype(int)

plots = {}
for i, col in enumerate(metric_columns):
    ranked = a.sort_values(col, ascending=False)
    plots[col] = px.scatter(x=ranked['ProgramID'], y=ranked[col])
    plots[col].update_layout(xaxis_title='Components', yaxis_title=col)

    # Fill the grid row-major; plotly rows/cols are 1-based.
    row_num, col_num = divmod(i, 2)
    row_num, col_num = row_num + 1, col_num + 1
    fig.add_trace(plots[col].data[0], row=row_num, col=col_num)
    # Categorical x axis keeps the ranked order even if ProgramIDs are numeric.
    fig.update_xaxes(showticklabels=False, type='category', row=row_num, col=col_num)
    fig.update_yaxes(ticksuffix = "  ", row=row_num, col=col_num)

fig.update_traces(hovertemplate="Program Name: %{x} <br> Value: %{y}", marker_color='black')
fig.update_layout(height=1050, width=1000,
                  plot_bgcolor='whitesmoke',
                  title_text="GEP Dashboard - v0.1 - Goodness of fit measures")

fig.write_html('example_output/0_goodness_of_fit.html')


In [None]:
# Page 2 - program assessment

# Batch association: adjust the Kruskal-Wallis p-values in `association_df`
# for multiple testing (scipy's default Benjamini-Hochberg), then put both
# volcano axes on a natural-log scale.
association_df['batch_sample_kruskall_wallis_fdr'] = stats.false_discovery_control(association_df.batch_sample_kruskall_wallis_pval)
association_df['batch_sample_kruskall_wallis_neg_log_fdr'] = -np.log(association_df['batch_sample_kruskall_wallis_fdr'])
association_df['batch_sample_kruskall_wallis_log_stat'] = np.log(association_df['batch_sample_kruskall_wallis_stat'])

# Program scores per cell, with each cell's sample and batch labels attached.
prog_df = pd.DataFrame(mdata['cNMF'].X, index=mdata['cNMF'].obs.index)
prog_df['sample'] = mdata['cNMF'].obs['sample']
prog_df['batch'] = mdata['cNMF'].obs['batch']
program_cols = [col for col in prog_df.columns if col not in ['sample', 'batch']]

# Loadings: transposed so rows are genes and columns are programs, then each
# program column is scaled by its largest absolute loading.
loadings = pd.DataFrame(mdata['cNMF'].varm['loadings'],
                        index=mdata['cNMF'].var_names,
                        columns=mdata['cNMF'].uns['var_names']).T
loadings = loadings/abs(loadings).max(axis=0)

# The dropdown pairs the i-th box trace with the i-th loadings column, so both
# must describe the same number of programs.
assert len(program_cols) == loadings.shape[1], 'program scores and loadings disagree on the number of programs'

# Make figure
fig = make_subplots(rows=2, cols=2,
                    specs=[
                           [{}, {}],
                           [{'colspan':2}, None],
                          ],
                    print_grid=True,
                    subplot_titles=('Enrichment w.r.t. batch',
                                    'Program distribution across batch',
                                    'Program-gene loadings (normalised)'),
                    vertical_spacing = 0.1,
                    horizontal_spacing = 0.1)

# Volcano of all programs. It does not depend on the selected program, so it
# is added once and always shown (previously one identical copy was added per
# program).
fig.add_trace(go.Scatter(x=association_df.batch_sample_kruskall_wallis_log_stat,
                         y=association_df.batch_sample_kruskall_wallis_neg_log_fdr,
                         customdata=association_df.index.values,
                         hovertemplate=" Program Name: %{customdata}",
                         showlegend=False, mode='markers',
                         marker_color='black'),
              row=1, col=1)
fig.update_xaxes(showticklabels=False, row=1, col=1)
fig.update_yaxes(title='Neg. log adjusted pval', ticksuffix = "  ", row=1, col=1)

# One box trace per program (scores by batch); only the first starts visible.
for k, r in enumerate(program_cols):
    fig.add_trace(go.Box(x=prog_df.batch,
                         y=prog_df[r],
                         visible=(k == 0),
                         customdata=prog_df.index.values,
                         hovertemplate="Cell barcode: %{customdata}",
                         marker_color='black',
                         showlegend=False),
                  row=1, col=2)
fig.update_yaxes(ticksuffix = "  ", row=1, col=2)

# One bar trace per program: its 100 highest normalised gene loadings.
for k, r in enumerate(loadings.columns):
    dfp = loadings.sort_values(r, ascending=False)[:100]
    fig.add_trace(go.Bar(x=dfp.index,
                         y=dfp[r],
                         name='',
                         orientation='v',
                         hovertemplate="Gene Name: %{x} <br> Loading: %{y}",
                         visible=(k == 0),
                         marker_color='black',
                         showlegend=False),
                  row=2, col=1)
fig.update_xaxes(showticklabels=False, row=2, col=1)
fig.update_yaxes(tickvals=np.arange(0,1.25,0.25), ticksuffix = "  ", row=2, col=1)

# Dropdown: trace order is [volcano, box_0..box_{n-1}, bar_0..bar_{n-1}], so the
# visibility mask must cover all 2n+1 traces explicitly rather than relying on
# plotly wrapping a length-n list. The title belongs in the layout-update dict
# (second element of `args`); in the trace dict it had no effect.
col_opts = list(loadings.columns)
buttons_opts = []
for i, opt in enumerate(col_opts):
    per_program = [j == i for j in range(len(col_opts))]
    buttons_opts.append(
        dict(
             method='update',
             label=opt,
             args=[{'visible': [True] + per_program + per_program,
                    'showlegend': False},
                   {'title.text': f"GEP Dashboard - v0.1 - Investigate GEPs - {opt}"}]
            ))

fig.update_layout(height=700, width=1000,
                  plot_bgcolor='whitesmoke',
                  title_text="GEP Dashboard - v0.1 - Investigate GEPs",
                  updatemenus = [go.layout.Updatemenu(
                                 active=0,
                                 buttons=buttons_opts,
                                 x=-0.1,
                                 xanchor='left',
                                 y=1,
                                 yanchor='bottom')]
                 )
# NOTE(review): "covarite" looks like a typo for "covariate"; kept as-is in case
# other code loads this exact filename.
fig.write_html('example_output/1_covarite_association.html')

In [None]:
# Plot trait enrichment interpretation
def plot_interactive_phewas(data, x_column='trait_reported',
                            y_column='-log10(p-value)',
                            color_column='trait_category',
                            filter_column='program_name',
                            significance_threshold=0.05,
                            annotation_cols=["program_name", "trait_reported",
                                             "trait_category", "P-value",
                                             "Genes", "study_id", "pmid"],
                           query_string="trait_category != 'measurement'",
                           title="Cell Program x OpenTargets GWAS L2G Enrichment"):
    """Display an interactive PheWAS-style scatter of trait enrichment results.

    A dropdown picks one value of `filter_column` (or "All"). The scatter of
    `y_column` against `x_column`, coloured by `color_column`, is redrawn for
    the matching rows, with a dashed line at -log10(significance_threshold).

    Parameters
    ----------
    data : pandas.DataFrame
        Enrichment table; must contain every column referenced by the other
        arguments.
    x_column, y_column, color_column : str
        Columns used for the scatter axes and colour.
    filter_column : str
        Column whose unique values populate the dropdown.
    significance_threshold : float
        Threshold drawn as a horizontal line at -log10(threshold).
    annotation_cols : list of str
        Extra columns shown on hover (not modified).
    query_string : str or None
        Optional `DataFrame.query` expression applied first; falsy to skip.
    title : str
        Figure title.

    Returns
    -------
    None
        The dropdown and plot area are shown via `IPython.display.display`.
    """
    # Apply the row filter before building the dropdown so it only offers
    # values that still have rows to plot.
    if query_string:
        data = data.query(query_string)

    filter_values = ['All'] + list(data[filter_column].unique())

    # Output widget that holds the current plot (ipywidgets is imported as `w`).
    output = w.Output()

    def update_plot(selected_value):
        """Redraw the scatter for `selected_value` ("All" shows every row)."""
        if selected_value == "All":
            filtered_data = data
        else:
            filtered_data = data[data[filter_column] == selected_value]

        fig = px.scatter(filtered_data, x=x_column, y=y_column, color=color_column,
                         title=title,
                         hover_data=annotation_cols)

        fig.update_layout(
            xaxis_title=x_column,
            yaxis_title=y_column,
            yaxis=dict(tickformat=".1f"),
            width=1000,
            height=800,
            xaxis_tickfont=dict(size=4)
        )

        # Dashed significance line on the -log10 scale.
        fig.add_hline(y=-np.log10(significance_threshold), line_dash="dash",
                      annotation_text=f'Significance Threshold ({significance_threshold})', annotation_position="top right")

        # Replace the previous plot rather than appending below it.
        with output:
            output.clear_output(wait=True)
            fig.show()

    dropdown = w.Dropdown(options=filter_values, description=f"{filter_column}:")

    def on_change(change):
        update_plot(change['new'])

    # Only react to changes of the selected value.
    dropdown.observe(on_change, names='value')

    # Show the widgets and draw the initial plot for the default selection
    # (previously nothing was drawn until the dropdown changed).
    display(w.VBox([dropdown, output]))
    update_plot(dropdown.value)