# Set-up

In [75]:
import os
import sys
import yaml
import logging
import mudata
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from typing import List, Dict, Tuple, Union, Optional, Literal

from scipy import stats
from statsmodels.stats.multitest import multipletests

# Change this path to wherever the repo is cloned locally
sys.path.append('/cellar/users/aklie/opt/gene_program_evaluation')

from src.evaluation import (
    compute_categorical_association,
    compute_geneset_enrichment,
    compute_trait_enrichment,
    compute_perturbation_association,
    compute_explained_variance_ratio,
    compute_motif_enrichment
)
from src.evaluation.enrichment_trait import process_enrichment_data

In [2]:
# I/O paths
path_config = "/cellar/users/aklie/opt/gene_program_evaluation/examples/evaluation/iPSC_EC/cNMF_30/evaluation_pipeline.yml"

# Load the pipeline configuration; the context manager closes the file handle
# as soon as parsing is done instead of leaving it to the garbage collector.
with open(path_config) as config_file:
    config = yaml.safe_load(config_file)

## I/O

In [3]:
io_config = config['io']
io_config

{'path_mdata': '/cellar/users/aklie/opt/gene_program_evaluation/examples/inference/iPSC_EC/cNMF/cNMF_30_0.2_gene_names.h5mu',
 'path_out': '/cellar/users/aklie/opt/gene_program_evaluation/examples/evaluation/iPSC_EC/cNMF_30',
 'data_key': 'rna',
 'prog_key': 'cNMF'}

In [4]:
# Load mdata
path_mdata = io_config['path_mdata']
mdata = mudata.read(path_mdata)
mdata

  utils.warn_names_duplicates("var")


In [5]:
# Take the modality keys from the loaded I/O config rather than repeating them
# as literals, so they cannot drift from the config file.
prog_key = io_config['prog_key']
data_key = io_config['data_key']

In [6]:
# Keep only the first 3 programs of the prog_key modality
prog_names = mdata.mod[prog_key].var_names[:3].tolist()
mdata.mod[prog_key] = mdata.mod[prog_key][:, prog_names]
mdata

# Perturbation association testing

## Read MuData

In [43]:
def read_mudata(
    mdata: Union[os.PathLike, mudata.MuData],
    inplace: bool = True,
    prog_key: Optional[str] = None,
    **kwargs
) -> mudata.MuData:
    """
    Read in the mudata object from a path or directly from the object itself.

    Parameters
    ----------
    mdata : Union[str, os.PathLike, mudata.MuData]
        Path to the mudata object or the mudata object itself.
    inplace : bool, optional
        Whether to work on the passed object itself, by default True.
        Forced to False (with a warning) when a path is given, because the
        freshly loaded object is not shared with the caller.
    prog_key : str, optional
        If given and a copy is made, the copy only contains this modality.
        If omitted, the whole object is copied. By default None.
    **kwargs
        Accepted for call compatibility; currently ignored.

    Returns
    -------
    mudata.MuData
        The loaded object, the input object itself (inplace=True), or a copy
        of the input object (inplace=False).

    Raises
    ------
    ValueError
        If a path is given that does not exist.
    """
    # Paths may arrive as plain strings or as os.PathLike (e.g. pathlib.Path)
    if isinstance(mdata, (str, os.PathLike)):
        if not os.path.exists(mdata):
            raise ValueError('Incorrect mudata specification.')
        if inplace:
            logging.warning('Changed to inplace=False since path was provided')
        # Nothing to copy: the object read from disk is already private
        return mudata.read(os.fspath(mdata))

    if inplace:
        return mdata

    # Copy so that downstream writes do not touch the caller's object
    if prog_key is None:
        return mdata.copy()
    return mudata.MuData({prog_key: mdata[prog_key].copy()})

In [44]:
mdata = read_mudata(mdata, inplace=False)
mdata

## Get guide metadata

In [45]:
def get_guide_metadata(
    mdata: mudata.MuData,
    prog_key: str,
    guide_names_key: str = 'guide_names',
    guide_targets_key: str = 'guide_targets'
) -> pd.DataFrame:
    """
    Build a per-guide table from the unstructured annotations of a modality.

    Parameters
    ----------
    mdata : mudata.MuData
        The mudata object.
    prog_key : str
        The key of the program modality in the mudata object.
    guide_names_key : str, optional
        Key in mdata[prog_key].uns holding the guide names, by default 'guide_names'.
    guide_targets_key : str, optional
        Key in mdata[prog_key].uns holding the target of each guide,
        by default 'guide_targets'.

    Returns
    -------
    pd.DataFrame
        Table indexed by guide name with a single 'Target' column.
    """
    prog_uns = mdata[prog_key].uns

    # Start from an index-only frame, then attach the targets as its one column
    guide_table = pd.DataFrame(index=prog_uns[guide_names_key])
    guide_table['Target'] = prog_uns[guide_targets_key]
    return guide_table

In [25]:
get_guide_metadata(mdata, prog_key, 'guide_names', 'guide_targets')

Unnamed: 0,Target
ACAA1_-_38178575.23-P1P2,ACAA1
ACAA1_+_38178488.23-P1P2,ACAA1
ACAA1_+_38178517.23-P1P2,ACAA1
ACAA1_+_38178559.23-P1P2,ACAA1
ACAA1_+_38178570.23-P1P2,ACAA1
...,...
NON-TARGETING_00324,non-targeting
NON-TARGETING_01858,non-targeting
NON-TARGETING_02288,non-targeting
NON-TARGETING_02543,non-targeting


## Compute perturbation associations

In [97]:
def compute_perturbation_association_(
    test_data: mudata.MuData,
    reference_data: mudata.MuData,
    program: str,
    level_name: str,
    test_stats_df: List[List],
    balanced: bool = False
):
    """
    Test one program for a score shift between perturbed and reference cells.

    Runs a Mann-Whitney U test on the scores of `program` and appends one row
    `[level_name, program, ref_mean, test_mean, log2FC, stat, pval]` to
    `test_stats_df`.

    Parameters
    ----------
    test_data
        Cells carrying the perturbation; indexed as `test_data[:, program].X`.
        NOTE(review): annotated as MuData, but the wrapper in this file passes
        row subsets of `mdata[prog_key]` - confirm the intended type.
    reference_data
        Reference cells, indexed the same way as `test_data`.
    program : str
        Name of the program (variable) to test.
    level_name : str
        Label of the tested level (target or guide) stored in the result row.
    test_stats_df : list of lists
        Accumulator that is mutated in place; nothing is returned.
    balanced : bool, optional
        If True, randomly subsample the reference cells (without replacement)
        down to the number of test cells, by default False. When the reference
        group is not larger than the test group, all reference cells are used.
    """
    def _program_scores(data):
        # Accept both sparse and dense score matrices and flatten to 1-D
        scores = data[:, program].X
        if hasattr(scores, 'toarray'):
            scores = scores.toarray()
        return np.asarray(scores).ravel()

    test_scores = _program_scores(test_data)
    reference_scores = _program_scores(reference_data)

    if balanced:
        # Sampling without replacement cannot draw more cells than exist
        n_draw = min(test_scores.shape[0], reference_scores.shape[0])
        keep = np.random.choice(reference_scores.shape[0], n_draw, replace=False)
        reference_scores = reference_scores[keep]

    # Calculate log2FC; zero means yield inf/nan rather than a warning storm
    ref_mean = np.mean(reference_scores)
    test_mean = np.mean(test_scores)
    with np.errstate(divide='ignore', invalid='ignore'):
        log2fc = np.log2(test_mean / ref_mean)

    # Compute Mann-Whitney U test (1-D inputs give scalar statistic and p-value)
    results = stats.mannwhitneyu(test_scores, reference_scores)

    # Append to test stats df
    test_stats_df.append(
        [level_name, program, ref_mean, test_mean, log2fc, results.statistic, results.pvalue]
    )

## Wrapper

In [110]:
# TODO: Add support for stratification by categorical levels
def compute_perturbation_association(
    mdata: Union[os.PathLike, mudata.MuData],
    prog_key: str,
    guide_names_key: str = 'guide_names',
    guide_targets_key: str = 'guide_targets',
    guide_assignments_key: str = 'guide_assignment',
    collapse_targets: bool = True,
    pseudobulk: bool = False,
    reference_targets: Union[str, List[str]] = ['non-targeting'],
    balanced: bool = False,
    n_jobs: int = 1,
    inplace: bool = True
):
    """Compute Mann-Whitney U test for cells targeted with a perturbation against a reference set of cells.

    Parameters
    ----------
    mdata : MuData or path
        mudata object containing anndata of program scores and cell-level metadata,
        or a path (str / os.PathLike) to read it from.
    prog_key : str
        index for the gene program anndata object (mdata[prog_key]) in the mudata object.
    guide_names_key : str (default: 'guide_names')
        key in mdata[prog_key].uns for 1D np.array of guide names.
    guide_targets_key : str (default: 'guide_targets')
        key in mdata[prog_key].uns for 1D np.array of target gene names. Should be in the same order as guide_names.
    guide_assignments_key : str (default: 'guide_assignment')
        key in mdata[prog_key].obsm for 2D cell x guide assignment matrix. Should correspond to guide_names and guide_targets.
    collapse_targets : bool (default: True)
        If target gene per guide is provided, perform tests on target levels.
        Takes precedence over pseudobulk: when both are True, pseudobulk is switched off.
    pseudobulk : bool (default: False)
        If multiple non-targeting guides are available - optionally test at pseudobulk level.
        The guide-level pseudobulk test itself is not implemented yet.
    reference_targets : str or list of str (default: ['non-targeting'])
        Target value(s) to use as reference distribution. The argument is never mutated.
    balanced : bool (default: False)
        If True, resample reference data to match number of observations in test
    n_jobs : int (default: 1)
        number of threads used to test the programs of one target/guide.
    inplace : bool (default: True)
        update the mudata object inplace or return the results table.
        Forced to False (with a warning) when mdata is given as a path.

    Returns
    -------
    if inplace (level is 'target' or 'guide'):
        UPDATES
        mdata[prog_key].varm['perturbation_association_{level}_stat']
        mdata[prog_key].varm['perturbation_association_{level}_pval']
        mdata[prog_key].uns['perturbation_association_{level}_names']
        and returns None
    else:
        RETURNS
        test_stats_df with columns '{level}_name', 'program_name', 'ref_mean',
        'test_mean', 'log2FC', 'stat', 'pval', 'adj_pval'

    Raises
    ------
    ValueError
        If a path is given that does not exist, or no guide matches reference_targets.
    NotImplementedError
        If guide-level pseudobulk testing is requested.
    """
    # Read in mudata if it is provided as a path (str or os.PathLike)
    frompath = False
    if isinstance(mdata, (str, os.PathLike)):
        if not os.path.exists(mdata):
            raise ValueError('Incorrect mudata specification.')
        mdata = mudata.read(os.fspath(mdata))
        if inplace:
            logging.warning('Changed to inplace=False since path was provided')
            inplace = False
        frompath = True
    if not inplace and not frompath:
        mdata = mudata.MuData({prog_key: mdata[prog_key].copy()})

    prog_data = mdata[prog_key]
    programs = list(prog_data.var_names)
    assignments = prog_data.obsm[guide_assignments_key]

    # Guide metadata
    guide_metadata = get_guide_metadata(
        mdata,
        prog_key,
        guide_names_key=guide_names_key,
        guide_targets_key=guide_targets_key
    )

    # Target-level testing wins over pseudobulk when both are requested
    if collapse_targets and pseudobulk:
        pseudobulk = False
        logging.info('Setting pseudobulk to False since collapse_targets is True')

    # Normalise the reference specification into a fresh list
    if isinstance(reference_targets, str):
        reference_targets = [reference_targets]
    else:
        reference_targets = list(reference_targets)

    is_reference = guide_metadata['Target'].isin(reference_targets).to_numpy()
    if not is_reference.any():
        raise ValueError('No guides found for reference targets: {}'.format(reference_targets))

    if pseudobulk:
        # Mean program score per guide: rows are guides, columns are programs
        guide_names = prog_data.uns[guide_names_key]
        pseudobulked_scores = pd.DataFrame(index=guide_names, columns=programs)

        #TODO: Parallelize
        for i, guide in enumerate(guide_names):
            guide_cells = assignments[:, i].astype(bool)
            for program in programs:
                pseudobulked_scores.loc[guide, program] = prog_data[guide_cells, program].X.mean()

        pseudobulked_scores['Target'] = prog_data.uns[guide_targets_key]

        # Grab a DataFrame of all reference guides
        reference_data = pseudobulked_scores.loc[pseudobulked_scores.Target.isin(reference_targets)]
    else:
        # Grab all cells assigned to at least one reference guide
        reference_guide_idx = np.flatnonzero(is_reference)
        reference_data = prog_data[np.any(assignments[:, reference_guide_idx].astype(bool), axis=1)]

    # Run tests; the worker appends one row per (level, program) to this list
    test_stats_df = []

    if collapse_targets:

        # Set level key
        level_key = 'target'

        # Run tests for each target (with multiple guides)
        targets_no_ref = [targ for targ in guide_metadata['Target'].unique() if targ not in reference_targets]
        for target in tqdm(targets_no_ref, desc='Testing perturbation association', unit='targets'):

            # Grab cells assigned to any guide of the current target
            test_guide_idx = np.flatnonzero((guide_metadata['Target'] == target).to_numpy())
            test_data = prog_data[assignments[:, test_guide_idx].any(-1).astype(bool)]

            # Run test for every program
            Parallel(n_jobs=n_jobs, backend='threading')(delayed(compute_perturbation_association_)(
                test_data,
                reference_data,
                program,
                target,
                test_stats_df,
                balanced) for program in programs)

    else:

        # Set level key
        level_key = 'guide'

        # Run tests for each guide that is not part of the reference
        test_guide_positions = np.flatnonzero(~is_reference)
        for test_guide_idx in tqdm(test_guide_positions, desc='Testing perturbation association', unit='guides'):
            guide = guide_metadata.index[test_guide_idx]

            # Testing pseudobulks at guide level requires multiple non-targeting guides as reference
            if pseudobulk:
                # TODO: Implement per guide probability under reference approach
                raise NotImplementedError()

            # Grab data for the current guide
            test_data = prog_data[assignments[:, test_guide_idx].astype(bool)]

            # Run test for every program
            Parallel(n_jobs=n_jobs, backend='threading')(delayed(compute_perturbation_association_)(
                test_data,
                reference_data,
                program,
                guide,
                test_stats_df,
                balanced) for program in programs)

    # Create DataFrame
    name_col = '{}_name'.format(level_key)
    test_stats_df = pd.DataFrame(
        test_stats_df,
        columns=[name_col, 'program_name', 'ref_mean', 'test_mean', 'log2FC', 'stat', 'pval']
    )

    # Correct for multiple testing (nothing to correct when no test was run)
    if len(test_stats_df) > 0:
        test_stats_df['adj_pval'] = multipletests(test_stats_df['pval'], method='fdr_bh')[1]
    else:
        test_stats_df['adj_pval'] = np.array([], dtype=float)

    # Return only the evaluations if not in inplace mode
    if not inplace:
        return test_stats_df

    # Store program x level matrices; combinations that were not tested stay NaN.
    # Rows follow var_names and columns follow first appearance in the results,
    # independent of the order in which the worker threads appended their rows.
    level_names = test_stats_df[name_col].unique()
    for value_col in ('stat', 'pval'):
        wide = test_stats_df.pivot(index='program_name', columns=name_col, values=value_col)
        wide = wide.reindex(index=programs, columns=level_names)
        prog_data.varm['perturbation_association_{}_{}'.format(level_key, value_col)] = wide.to_numpy(dtype=float)
    prog_data.uns['perturbation_association_{}_names'.format(level_key)] = level_names
    return None

In [104]:
# Run wrapper function
# The settings dict was never defined in this notebook; load it the same way as
# the categorical association settings.
# NOTE(review): section name assumed to be 'perturbation_association' - confirm against the YAML.
perturbation_assocation_config = config['perturbation_association']
perturbation_assocation_df = compute_perturbation_association(
    mdata, 
    prog_key=prog_key,
    collapse_targets=perturbation_assocation_config['collapse_targets'],
    pseudobulk=perturbation_assocation_config['pseudobulk'],
    reference_targets=perturbation_assocation_config['reference_targets'],
    n_jobs=perturbation_assocation_config['n_jobs'],
    inplace=perturbation_assocation_config['inplace'],
    balanced=True
)

Testing perturbation association:   0%|          | 0/298 [00:00<?, ?targets/s]

['ACAA1', '2', 0.07231384822023985, 0.0727679635493973, 0.009031484580391644, 1311736.5, 0.7015746363089513]


In [107]:
perturbation_assocation_df.sort_values("adj_pval")

Unnamed: 0,target_name,program_name,ref_mean,test_mean,log2FC,stat,pval,adj_pval
518,SETDB1,2,0.072214,0.105496,0.546854,1062804.0,6.742828e-26,6.028088e-23
395,PLCG1,1,0.059967,0.023381,-1.358855,2267287.0,9.445172e-15,4.221992e-12
486,RIBC1,2,0.070734,0.090998,0.363415,2146274.0,2.613326e-14,7.787712e-12
86,CBLL1,1,0.065030,0.113998,0.809829,1819187.5,5.193490e-14,1.160745e-11
638,TAF2,1,0.069444,0.031473,-1.141732,681832.5,4.742256e-13,8.479154e-11
...,...,...,...,...,...,...,...,...
812,CYC1,2,0.074483,0.075100,0.011901,1726766.5,9.832281e-01,9.957998e-01
729,ZC3H8,0,0.085446,0.079813,-0.098405,550387.0,9.887068e-01,9.957998e-01
62,BRD3,2,0.073512,0.073588,0.001488,2566740.5,9.884043e-01,9.957998e-01
838,LUZP6,1,0.069368,0.068850,-0.010803,172340.5,9.918289e-01,9.957998e-01


In [108]:
perturbation_assocation_df.sort_values("adj_pval")

Unnamed: 0,target_name,program_name,ref_mean,test_mean,log2FC,stat,pval,adj_pval
518,SETDB1,2,0.072214,0.105496,0.546854,1062804.0,6.742828e-26,6.028088e-23
395,PLCG1,1,0.059967,0.023381,-1.358855,2267287.0,9.445172e-15,4.221992e-12
486,RIBC1,2,0.070734,0.090998,0.363415,2146274.0,2.613326e-14,7.787712e-12
86,CBLL1,1,0.065030,0.113998,0.809829,1819187.5,5.193490e-14,1.160745e-11
638,TAF2,1,0.069444,0.031473,-1.141732,681832.5,4.742256e-13,8.479154e-11
...,...,...,...,...,...,...,...,...
812,CYC1,2,0.074483,0.075100,0.011901,1726766.5,9.832281e-01,9.957998e-01
729,ZC3H8,0,0.085446,0.079813,-0.098405,550387.0,9.887068e-01,9.957998e-01
62,BRD3,2,0.073512,0.073588,0.001488,2566740.5,9.884043e-01,9.957998e-01
838,LUZP6,1,0.069368,0.068850,-0.010803,172340.5,9.918289e-01,9.957998e-01


# Categorical association testing

In [111]:
categorical_association_config = config['categorical_association']
categorical_association_config

{'categorical_key': 'sample',
 'pseudobulk_key': None,
 'test': 'dunn',
 'n_jobs': -1,
 'inplace': False}

In [117]:
from scipy import stats, sparse

In [118]:
categorical_key = categorical_association_config['categorical_key']
pseudobulk_key = categorical_association_config['pseudobulk_key']


In [119]:
prog_nam="0"

In [120]:
samples = []
if pseudobulk_key is None:

    # Every single cell is its own data point: one score vector per category
    category_labels = mdata[prog_key].obs[categorical_key].astype(str)
    for category in category_labels.unique():
        sample_ = mdata[prog_key][category_labels == category, prog_nam].X[:, 0]
        # Sparse columns are densified and flattened to 1-D
        samples.append(sample_.toarray().flatten() if sparse.issparse(sample_) else sample_)

In [122]:
stat, pval = stats.kruskal(*samples, nan_policy='propagate')
stat, pval

(68816.47038134227, 0.0)

In [139]:
# Store one random permutation of the category labels as a new obs column
shuffled_key = f"shuffled_{categorical_key}"
mdata[prog_key].obs[shuffled_key] = mdata[prog_key].obs[categorical_key].sample(frac=1).values

# Rebuild the per-category score vectors from the shuffled labels
shuffled_labels = mdata[prog_key].obs[shuffled_key].astype(str)
samples = []
for category in shuffled_labels.unique():
    sample_ = mdata[prog_key][shuffled_labels == category, prog_nam].X[:, 0]
    # Sparse columns are densified and flattened to 1-D
    samples.append(sample_.toarray().flatten() if sparse.issparse(sample_) else sample_)

In [140]:
stat, pval = stats.kruskal(*samples, nan_policy='propagate')
stat, pval

(0.8442081296082438, 0.8388663900924934)

In [121]:
samples

[ArrayView([0.30733788, 0.23179917, 0.29495603, ..., 0.17720768,
            0.30303675, 0.21068594]),
 ArrayView([0.        , 0.        , 0.        , ..., 0.02953346,
            0.02114811, 0.06008436]),
 ArrayView([0.00582987, 0.00041388, 0.        , ..., 0.00470786,
            0.02357423, 0.        ]),
 ArrayView([0., 0., 0., ..., 0., 0., 0.])]

# DONE!

---