# Set-up

In [None]:
import os
import sys
import yaml
import logging
import mudata
import pandas as pd

# Change this path to wherever the gene_network_evaluation repo is cloned locally
sys.path.append('/oak/stanford/groups/engreitz/Users/ymo/Tools/test/gene_network_evaluation')

from src.evaluation import (
    compute_categorical_association,
    compute_geneset_enrichment,
    compute_trait_enrichment,
    compute_perturbation_association,
    compute_explained_variance_ratio,
    compute_motif_enrichment
)
from src.evaluation.enrichment_trait import process_enrichment_data

# categorical association, geneset enrichment, GO term enrichment, trait enrichment

In [None]:
# Root directory of this torch-cNMF run. It must be an absolute path: the
# input .h5mu files are read from the same directory, and a relative
# "oak/..." path would make os.makedirs / to_csv write under the current
# working directory instead.
folder = "/oak/stanford/groups/engreitz/Users/ymo/NMF_re-inplementing/Results/torch-cNMF_evaluation/100k_cells_10iter_torch_mu_batch"

# GWAS table passed to compute_trait_enrichment (same for every k, so defined once)
gwas_l2g_path = '/oak/stanford/groups/engreitz/Users/ymo/Tools/cNMF_benchmarking/cNMF_benchmakring_pipeline/Evaluation/Resources/OpenTargets_L2G_Filtered.csv.gz'

# (enrichr library, output file suffix) pairs for the gene-set style enrichments
geneset_libraries = [
    ('Reactome_2022', 'geneset_enrichment'),            # Gene-set enrichment
    ('GO_Biological_Process_2023', 'GO_term_enrichment'),  # GO Term enrichment
]

for k in [30, 60, 80, 100, 200, 250, 300]: # 400

    # NOTE(review): this creates {folder}/Eval/{k}, but every result below is
    # written directly into {folder} with a "{k}_" prefix. Confirm which layout
    # is intended; the write locations are left unchanged here.
    os.makedirs(f"{folder}/Eval/{k}", exist_ok=True)

    # Load the MuData object holding the cNMF programs for this k
    mdata = mudata.read(f"{folder}/adata/cNMF_{k}_2_0.h5mu")

    # Run categorical association of programs with the 'sample' covariate
    results_df, posthoc_df = compute_categorical_association(
        mdata, prog_key='cNMF', categorical_key='sample',
        pseudobulk_key=None, test='dunn', n_jobs=-1, inplace=False,
    )

    # results_df was made wide form to insert into .var of the program anndata.
    results_df.to_csv('{}/{}_categorical_association_results.txt'.format(folder, k), sep='\t', index=False)
    posthoc_df.to_csv('{}/{}_categorical_association_posthoc.txt'.format(folder, k), sep='\t', index=False)

    # The string literal below is disabled code (perturbation association), kept for reference.
    '''
    # Run perturbation assocation
    for samp in mdata['rna'].obs['sample'].unique():
        mdata_ = mdata[mdata['rna'].obs['sample']==samp]
        test_stats_df = compute_perturbation_association(mdata_, prog_key='cNMF', 
                                                        collapse_targets=True,
                                                        pseudobulk=False,
                                                        reference_targets=('non-targeting'),
                                                        n_jobs=-1, inplace=False)

        test_stats_df.to_csv('{}/{}_perturbation_association_results_{}.txt'.format(folder,k,samp), sep='\t', index=False)
    '''

    # Gene-set and GO term enrichment: identical call, only the library differs
    for library, out_suffix in geneset_libraries:
        pre_res = compute_geneset_enrichment(
            mdata, prog_key='cNMF', data_key='rna', prog_nam=None,
            organism='human', library=library, method="fisher",
            database='enrichr', loading_rank_thresh=300, n_jobs=-1,
            inplace=False, user_geneset=None,
        )
        pre_res.to_csv('{}/{}_{}.txt'.format(folder, k, out_suffix), sep='\t', index=False)

    # Run trait enrichment against the GWAS table
    pre_res_trait = compute_trait_enrichment(
        mdata, gwas_data=gwas_l2g_path,
        prog_key='cNMF', prog_nam=None, data_key='rna',
        library='OT_GWAS', n_jobs=-1, inplace=False,
        key_column='trait_efos', gene_column='gene_name',
        method='fisher', loading_rank_thresh=300,
    )
    pre_res_trait.to_csv('{}/{}_trait_enrichment.txt'.format(folder, k), sep='\t', index=False)

    

# Motif Enrichment

In [None]:
# Thresholds
# Score cutoffs for enhancer / promoter links (compared against 'ABC.Score'
# in the commented-out formatting cell).
score_thresh_abc_e2g_enhancer, score_thresh_abc_e2g_promoter = 0.015, 0.8
# Significance cutoffs for enhancer / promoter loci (passed as `sig` in the
# commented-out motif enrichment cell).
fimo_thresh_enhancer, fimo_thresh_promoter = 1e-6, 1e-4

# Directory for the motif enrichment result files
output_path = 'cNMF_100/'

In [None]:
# # Format files
# for i in range(4):
#     e2g = pd.read_csv('scE2G_links/EnhancerPredictionsAllPutative.ForVariantOverlap.shrunk150bp_D{}.tsv'.format(i), sep='\t')

#     e2g_enhancers = e2g.loc[(e2g['class']!='promoter') &\
#                             (e2g['ABC.Score']>score_thresh_abc_e2g_enhancer)]
#     e2g_enhancers = e2g_enhancers.loc[:,['chr', 'start', 'end', 'name', 'class', 'ABC.Score', 'TargetGene']]
#     e2g_enhancers.columns = [        'chromosome', 'start', 'end', 'seq_name', 'seq_class', 'seq_score', 'gene_name']

#     e2g_enhancers.to_csv('scE2G_links/EnhancerPredictionsAllPutative.ForVariantOverlap.shrunk150bp_D{}_enhancer.tsv'.format(i), 
#                           sep='\t', index=False)

#     e2g_promoters = e2g.loc[(e2g['class']=='promoter') &\
#                             (e2g['ABC.Score']>score_thresh_abc_e2g_promoter)]
#     e2g_promoters = e2g_promoters.loc[:,['chr', 'start', 'end', 'name', 'class', 'ABC.Score', 'TargetGene']]
#     e2g_promoters.columns = ['chromosome', 'start', 'end', 'seq_name', 'seq_class', 'seq_score', 'gene_name']

#     e2g_promoters.to_csv('scE2G_links/EnhancerPredictionsAllPutative.ForVariantOverlap.shrunk150bp_D{}_promoter.tsv'.format(i), 
#                         sep='\t', index=False)

In [None]:
# Read the MuData file with the cNMF programs and display it as the cell output
mdata = mudata.read(
    '../../shared/240810_ipsc_ec_cNMF_analysis/cNMF_100_0.2_gene_names.h5mu'
)
mdata

In [None]:
# # Run in script

# # Run motif enrichment and save results
# os.makedirs(output_path, exist_ok=True)
# for i in range(4):
#     for class_, thresh in [('enhancer', fimo_thresh_enhancer), 
#                            ('promoter', fimo_thresh_promoter)]:

#         loci_file = 'scE2G_links/EnhancerPredictionsAllPutative.ForVariantOverlap.shrunk150bp_D{}_{}.tsv'.format(i, class_)

#         motif_match_df, motif_count_df, motif_enrichment_df = compute_motif_enrichment(
#             mdata, 
#             prog_key='cNMF',
#             data_key='rna',
#             motif_file='../../../gene_program_evaluation/gene_network_evaluation/smk/resources/hocomoco_meme.meme',
#             seq_file='../../../../data/hg38/hg38.fa',
#             loci_file=loci_file,
#             window=1000,
#             sig=thresh,
#             eps=1e-4,
#             n_top=2000,
#             n_jobs=-1,
#             inplace=False
#         )

#         motif_match_df.to_csv(os.path.join(output_path, f'cNMF_{class_}_pearson_topn2000_sample_D{i}_motif_match.txt'), sep='\t', index=False)
#         motif_count_df.to_csv(os.path.join(output_path, f'cNMF_{class_}_pearson_topn2000_sample_D{i}_motif_count.txt'), sep='\t', index=False)
#         motif_enrichment_df.to_csv(os.path.join(output_path, f'cNMF_{class_}_pearson_topn2000_sample_D{i}_motif_enrichment.txt'), sep='\t', index=False)

In [None]:
# Gather every motif enrichment table found in cNMF_100/ into one DataFrame,
# tagging each row with the sample and sequence class parsed from the file name.
enrichment_frames = []
for fname in os.listdir('cNMF_100'):
    # Skip the motif_match / motif_count files and anything unrelated
    if 'motif_enrichment' not in fname:
        continue

    frame = pd.read_csv('cNMF_100/{}'.format(fname), sep='\t')

    # The sample id is the text between the last 'sample_' and the following '_motif'
    frame['sample'] = fname.split('sample_')[-1].split('_motif')[0]

    # First matching label wins ('promoter' is checked before 'enhancer');
    # files matching neither get no 'class' value.
    for class_label in ('promoter', 'enhancer'):
        if class_label in fname:
            frame['class'] = class_label
            break

    enrichment_frames.append(frame)

motif_enrichment_data = pd.concat(enrichment_frames, ignore_index=True)
# motif_enrichment_data.to_csv('cNMF_100_motif_enrichment.txt', sep='\t')
# motif_enrichment_data.loc[motif_enrichment_data['class']=='promoter'].to_csv('cNMF_100_motif_enrichment_promoter.txt', sep='\t')
# motif_enrichment_data.loc[motif_enrichment_data['class']=='enhancer'].to_csv('cNMF_100_motif_enrichment_enhancer.txt', sep='\t')

In [None]:
# Build one summary table per sequence class: for each program, the top 10
# significantly enriched motifs (by 'stat') and the number of significant rows.
for seq_class in ['promoter', 'enhancer']:
    class_rows = motif_enrichment_data.loc[motif_enrichment_data['class'] == seq_class]
    # Column-wise minimum within each (sample, program, motif) group
    motif_enrichment_data_ = (
        class_rows
        .groupby(['sample', 'program_name', 'motif'])
        .min()
        .reset_index()
    )

    programs = motif_enrichment_data_.program_name.unique()
    motif_summary_data = pd.DataFrame(
        index=programs,
        columns=['top10_motifs', 'num_enriched_motifs'],
    )

    for prog in programs:
        # Rows of this program that pass the adjusted p-value cutoff
        is_significant = (motif_enrichment_data_.program_name == prog) & (motif_enrichment_data_.adj_pval <= 0.05)
        significant_rows = motif_enrichment_data_.loc[is_significant]

        # Ten highest-'stat' motifs among the significant rows
        top_motifs = significant_rows.sort_values('stat', ascending=False).head(10).motif.values

        motif_summary_data.loc[prog, 'top10_motifs'] = ', '.join(top_motifs.tolist())
        motif_summary_data.loc[prog, 'num_enriched_motifs'] = significant_rows.shape[0]

    motif_summary_data.to_csv('motif_summary_data_{}.txt'.format(seq_class), sep='\t')

In [None]:
# # Copy out to dashboard
# out_dir = '../../shared/250110_ipsc_ec_dashboard_setup/cNMF_100/'
# for fil in os.listdir('cNMF_100'):
#     if 'motif_' in fil:
#         class_ = fil.split('_')[1]
#         name = fil.split('_sample_')[1]
#         new_nam = f'cNMF_100_{class_}_test_pearsonr_sample_{name}'
#         if not name.startswith('D0'):
#             new_nam = new_nam.replace('sample_', 'sample_sample_')
#         os.system('cp {} {}'.format(os.path.join('cNMF_100', fil),
#                                     os.path.join(out_dir, new_nam)))