In [1]:
import sys
import os

import pandas as pd
import numpy as np
import scanpy as sc

from sklearn.cluster import KMeans
from scipy.stats import hypergeom
from sklearn.metrics import pairwise_distances
from itertools import combinations


import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams.update({'axes.labelsize' : 'large',
                     'pdf.fonttype':42
                    }) 
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

import multiprocessing as mp

import gc
import warnings
import time
import pickle
import json
import math

import torch

from importlib import reload
import util_functions
import energy_distance_calc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Show the installed scanpy version.
sc.__version__

'1.9.8'

In [3]:
!nvidia-smi

Mon Apr  7 15:58:12 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-PCIE-32GB           Off |   00000000:3B:00.0 Off |                    0 |
| N/A   28C    P0             25W /  250W |     164MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
!free -m

              total        used        free      shared  buff/cache   available
Mem:         385430        9606      374351          44        1471      374562
Swap:        131071        1295      129776


<h3>Load PCA matrix and gRNA-Cell name dictionary</h3>

In [5]:
# Read the pipeline configuration from the JSON file next to the notebook.
json_fp = "./config.json"
with open(json_fp, 'r') as fp:
    config = json.load(fp)

# Use the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Annotation table, read from the configured output folder (first column is the index).
annotation_df = pd.read_csv(
    os.path.join(
        config["output_file_name_list"]["OUTPUT_FOLDER"],
        config["output_file_name_list"]["annotation_file"],
    ),
    index_col=0,
)

In [6]:
# Obtain the PCA table and the gRNA -> cells dictionary via util_functions.load_files.
# The two os.path.join arguments are the pickle paths inside the output folder;
# overwrite=False is passed through unchanged.
_output_cfg = config["output_file_name_list"]
pca_df, gRNA_dict = util_functions.load_files(
    config["input_data"]["h5ad_file"],
    config["input_data"]["sgRNA_file"],
    os.path.join(_output_cfg["OUTPUT_FOLDER"], _output_cfg["pca_table"]),
    os.path.join(_output_cfg["OUTPUT_FOLDER"], _output_cfg["gRNA_dict"]),
    overwrite=False,
)

--- Processing PCA data ---
PCA file './pipeline_output/pca_dataframe.pickle' not found. Loading PCA from AnnData.
Loading input file '/project/shared/gcrb_igvf/data/shared/TF_Perturbseq_full/H5AD/V0.1_TF_Perturbseq_full_sgRNACells_filtered_w_embedding_full_dataset.h5ad'...
Extracting PCA data from 'X_pca'...
Saving PCA data to './pipeline_output/pca_dataframe.pickle'...
Releasing AnnData object from memory.

--- Processing gRNA dictionary ---
Dictionary file './pipeline_output/gRNA_dictionary.pickle' not found. Creating dictionary from sgRNA data.
Loading sgRNA file '/project/GCRB/Hon_lab/s215194/Single_Cell/TF_perturbseq_full/_all_lanes_combined/aggr_dataframe/aggr_combined_df_full.pkl'...
Transposing sgRNA data and filtering by cells present in input data...
Extracting non-zero count cell-gRNA pairs...
Creating gRNA dictionary (improved method)...


Processing gRNA pairs: 100%|██████████| 1394293/1394293 [00:00<00:00, 1899439.31it/s]


gRNA dictionary creation complete. Found 11634 types of gRNAs.
Saving gRNA dictionary to './pipeline_output/gRNA_dictionary.pickle'...
Releasing sgRNA data from memory.

--- Processing finished ---


In [7]:
# Display the annotation table loaded above.
annotation_df

Unnamed: 0,protospacer_ID,target_transcript_name,target_gene_name,source,protospacer,reverse_compliment
0,DNAJC19_ B,DNAJC19,DNAJC19,pos_control,GGGAACTCCTGTAAGGTCAG,CTGACCTTACAGGAGTTCCC
1,POLR1D_ B,POLR1D,POLR1D,pos_control,GGGAAGCAAGGACCGACCGA,TCGGTCGGTCCTTGCTTCCC
2,OR5K2-2,OR5K2,OR5K2,neg_control,GAAAAAATTGTAGAGGAATA,TATTCCTCTACAATTTTTTC
3,SP1_+_53773993.23-P1P2-1,SP1:P1P2,SP1,target,GAAAAACGCGGACGCTGACG,CGTCAGCGTCCGCGTTTTTC
4,SP8_-_20826141.23-P1P2-2,SP8:P1P2,SP8,target,GAAAAAGATCCTCTGAGAGG,CCTCTCAGAGGATCTTTTTC
...,...,...,...,...,...,...
14353,ZNF532_-_56532303.23-P1-2,ZNF532:P1,ZNF532,target,GTTTTGGCTGCCATGAAGGG,CCCTTCATGGCAGCCAAAAC
14354,ZNF829_-_37406927.23-P1P2-2,ZNF829:P1P2,ZNF829,target,GTTTTGGTCCCCAGGAGAAC,GTTCTCCTGGGGACCAAAAC
14355,NANOG_+_7942459.23-P1P2-2,NANOG:P1P2,NANOG,target,GTTTTTCCATTATAACTTGG,CCAAGTTATAATGGAAAAAC
14356,OR8B3-5,OR8B3,OR8B3,neg_control,GTTTTTGTCTTCAAAAATCT,AGATTTTTGAAGACAAAAAC


In [8]:
gRNA_region_dict = {}
count_region_dict = {}

# Group the protospacers that have an entry in gRNA_dict by their target transcript.
for _, annotation_row in annotation_df.iterrows():
    protospacer_id = annotation_row.protospacer_ID
    if protospacer_id not in gRNA_dict:
        continue
    gRNA_region_dict.setdefault(annotation_row.target_transcript_name, []).append(protospacer_id)

# Replace each list by its sorted unique values and record how many there are.
for region_name, protospacer_ids in gRNA_region_dict.items():
    unique_ids = np.unique(protospacer_ids)
    gRNA_region_dict[region_name] = unique_ids
    count_region_dict[region_name] = len(unique_ids)

In [9]:
# Number of target regions that have at least one detected gRNA.
len(gRNA_region_dict)

2228

<h3>Calculate e-distance between gRNA per target</h3>

In [10]:
# Container for the per-target result DataFrames.
result_df_dict = {}

In [11]:
def generate_sgrna_group_combinations(gRNA_list_target, combi_count):
    """
    Build the unique, unordered pairs of disjoint sgRNA groups for one target.

    Two regimes, selected by the number of sgRNAs supplied:

    - More than 6 sgRNAs: every subset of size ``combi_count`` is taken, and each
      subset is split into a non-empty first group and its complement.
    - 6 or fewer sgRNAs: every pair of non-empty, disjoint subsets is generated;
      the two groups need not cover all sgRNAs.

    A pair and its mirror image (groups swapped) count as the same pair; only the
    first one encountered is kept. Members of each group always appear in the
    order they have in ``gRNA_list_target``, so the result is reproducible from
    one interpreter run to the next.

    Args:
        gRNA_list_target (sequence): sgRNA names for the current target region.
        combi_count (int): Subset size used when there are more than 6 sgRNAs.

    Returns:
        list: Unique ``(group1, group2)`` tuples, each group being a list of
              sgRNA names, e.g. ``[(['gA1'], ['gA2', 'gA3']), ...]``. Empty when
              there are more than 6 sgRNAs but fewer than ``combi_count``.
    """
    guides = tuple(gRNA_list_target)
    total_gRNA_num = len(guides)

    seen_pairs = set()
    unique_combis = []

    def _record(first_group, second_group):
        # Canonical key: sort inside each group, then sort the two groups, so
        # that (A, B) and (B, A) collapse onto the same entry.
        key = tuple(sorted((tuple(sorted(first_group)), tuple(sorted(second_group)))))
        if key not in seen_pairs:
            seen_pairs.add(key)
            unique_combis.append((list(first_group), list(second_group)))

    if total_gRNA_num > 6:
        if total_gRNA_num < combi_count:
            print(f"Warning: total_gRNA_num ({total_gRNA_num}) < combi_count ({combi_count}). Skipping combination generation for this target.")
            return []
        for subset in combinations(guides, combi_count):
            for first_size in range(1, combi_count):
                for first_group in combinations(subset, first_size):
                    chosen = set(first_group)
                    # Ordered complement of the first group within the subset.
                    # A plain set difference would list the members in hash
                    # order, which changes between runs for string names.
                    second_group = list(dict.fromkeys(g for g in subset if g not in chosen))
                    if second_group:
                        _record(first_group, second_group)
    else:
        for first_size in range(1, total_gRNA_num):
            for first_group in combinations(guides, first_size):
                remaining = tuple(g for g in guides if g not in first_group)
                if not remaining:
                    continue
                # The second group is any non-empty subset of what is left.
                for second_size in range(1, len(remaining) + 1):
                    for second_group in combinations(remaining, second_size):
                        _record(first_group, second_group)

    return unique_combis


def get_cells_for_sgrna_groups(group1, group2, gRNA_dict):
    """
    Collect the unique cell identifiers covered by each of two sgRNA groups.

    Args:
        group1 (list): sgRNA names forming the first group.
        group2 (list): sgRNA names forming the second group.
        gRNA_dict (dict): Maps an sgRNA name to a list/array of cell identifiers.

    Returns:
        tuple: ``(cells1, cells2)`` as numpy arrays of unique identifiers; an
               empty group yields an empty array. ``(None, None)`` is returned
               when a name is missing from ``gRNA_dict`` or any other error
               occurs while gathering the cells (a message is printed).
    """
    def _pooled_cells(guide_names):
        # Union of the cells of every guide in the group.
        if not guide_names:
            return np.array([])
        per_guide = [gRNA_dict[guide] for guide in guide_names]
        return np.unique(np.concatenate(per_guide))

    try:
        return _pooled_cells(group1), _pooled_cells(group2)
    except KeyError as e:
        print(f"Error: sgRNA name {e} not found in gRNA_dict.")
    except Exception as e:
        print(f"An error occurred retrieving cells: {e}")
    return None, None
    
def calculate_distance(X, cell_test1, cell_test2, device, index_combi=None):
    """
    Compute the distance between two cell groups via
    ``energy_distance_calc.permutation_test``.

    The computation is first attempted on ``device``. If that raises an error
    whose message mentions "memory" (heuristic for an out-of-memory failure),
    it is retried once on the CPU. Any other failure is reported and mapped to
    the sentinel value ``-1.0``.

    Args:
        X: Data matrix handed to ``permutation_test`` (e.g. PCA coordinates).
        cell_test1: Cell identifiers of the first group, or None.
        cell_test2: Cell identifiers of the second group, or None.
        device: Primary computation device (string or ``torch.device``).
        index_combi (int, optional): Zero-based index of the combination, used
            only to label error messages. When omitted, the messages report the
            sizes of the two groups instead.

    Returns:
        float: The observed distance, or ``-1.0`` when either group is None or
               empty, or when the calculation fails.
    """
    if cell_test1 is None or cell_test2 is None:  # cell retrieval failed upstream
        return -1.0
    if len(cell_test1) == 0 or len(cell_test2) == 0:
        # The distance is not defined for an empty group; signal it with -1.0.
        return -1.0

    # Label used in error messages. It is built from the arguments so that the
    # function never has to read a loop variable from the notebook's globals.
    if index_combi is not None:
        combi_label = f"combi {index_combi + 1}"
    else:
        combi_label = f"groups of {len(cell_test1)} and {len(cell_test2)} cells"

    try:
        # The two positional 1s and return_permute=False are forwarded as-is;
        # their meaning is defined by energy_distance_calc.permutation_test.
        obs_edist = energy_distance_calc.permutation_test(X, cell_test1, cell_test2, device,
                                                          1, 1, return_permute=False).cpu()
        return obs_edist.item()
    except Exception as e_gpu:
        if 'memory' not in str(e_gpu).lower():
            # Not an out-of-memory condition: report and give up.
            print(f"\nError: GPU calculation failed for {combi_label} (not OOM). Error: {e_gpu}")
            return -1.0
        try:
            # Looks like an out-of-memory error: retry once on the CPU.
            obs_edist = energy_distance_calc.permutation_test(X, cell_test1, cell_test2, "cpu",
                                                              1, 1, return_permute=False).cpu()
            return obs_edist.item()
        except Exception as e_cpu:
            print(f"\nError: CPU calculation failed for {combi_label} after GPU OOM. Data too large? Error: {e_cpu}")
            return -1.0


def format_results_dataframe(res, total_combis, gRNA_list_target):
    """
    Assemble the per-combination distances into a ranked DataFrame.

    The frame is sorted by ascending ``e_dist``, receives a ``rank`` column equal
    to the row position after sorting, and gets one boolean column per sgRNA
    telling whether that sgRNA belongs to either group of the combination.

    Args:
        res (list): Distances, one per combination (-1.0 marks a failure).
        total_combis (list): ``(group1, group2)`` tuples of sgRNA-name lists.
        gRNA_list_target (np.array): All sgRNA names of the current target.

    Returns:
        pd.DataFrame: Columns ``e_dist``, ``combis``, ``rank`` followed by one
                      column per sgRNA name.
    """
    if len(res) != len(total_combis):
        print(f"Warning: Mismatch between results count ({len(res)}) and combinations count ({len(total_combis)}). DataFrame might be incorrect.")
        # Keep only the leading entries that both inputs have.
        shared_len = min(len(res), len(total_combis))
        res, total_combis = res[:shared_len], total_combis[:shared_len]

    ranked = (
        pd.DataFrame({"e_dist": res, "combis": total_combis})
        .sort_values(by="e_dist")
        .reset_index(drop=True)
    )
    ranked["rank"] = ranked.index

    def _has_guide(pair, guide):
        # Anything that is not a two-element list/tuple counts as "not a member".
        if not (isinstance(pair, (list, tuple)) and len(pair) == 2):
            return False
        return (guide in pair[0]) or (guide in pair[1])

    # One membership flag column per sgRNA.
    for guide in gRNA_list_target:
        ranked[guide] = ranked["combis"].apply(_has_guide, args=(guide,))

    return ranked

In [12]:
# --- Main Processing Loop ---
#
# For every target region: enumerate pairs of sgRNA groups, compute the distance
# between the cells of each pair, and store the sorted results per target.
#
# Names used from earlier cells:
#   gRNA_region_dict: maps region name -> array of sgRNA names
#   gRNA_dict:        maps sgRNA name -> list/array of cells
#   pca_df:           PCA table, passed to calculate_distance as the data matrix
#   device:           computation device chosen in the configuration cell
# Produced here:
#   result_df_dict:   maps target region -> result DataFrame
combi_count = 6  # passed to generate_sgrna_group_combinations as the combination size

result_df_dict = {} # Initialize result dictionary

test_region_np = np.array(list(gRNA_region_dict.keys()))
# Exclude the "non-targeting" region from the regions to process.
test_region_np = test_region_np[~(test_region_np=="non-targeting")]

print(f"Processing {len(test_region_np)} target regions...")

# Outer loop iterates through target regions
pbar_targets = tqdm(test_region_np, desc="Overall Target Regions")
for target in pbar_targets:
    # Show the region being processed next to the progress bar.
    pbar_targets.set_postfix({
            "Current Target": target
        })

    gRNA_list_target = np.array(gRNA_region_dict[target])
    # A pair of groups needs at least two distinct sgRNAs.
    if len(gRNA_list_target) < 2:
        print(f"Skipping target '{target}': Needs at least 2 gRNAs, found {len(gRNA_list_target)}.")
        continue

    # 1. Generate combinations of sgRNA groups
    total_combis = generate_sgrna_group_combinations(gRNA_list_target, combi_count)

    if not total_combis:
        print(f"Skipping target '{target}': No valid combinations generated.")
        continue

    # 2. Calculate distance for each combination
    # res[i] is the distance for total_combis[i]; the two lists stay aligned.
    res = []
    for index_combi, (combi_test1, combi_test2) in enumerate(total_combis):
        # Get cells for the current combination
        cell_test1, cell_test2 = get_cells_for_sgrna_groups(combi_test1, combi_test2, gRNA_dict)
        
        # Calculate the distance; calculate_distance returns -1.0 when it cannot
        # produce a value, so one entry is appended for every combination.
        distance = calculate_distance(pca_df, cell_test1, cell_test2, device)
        res.append(distance)

    # 3. Format results into a DataFrame
    result_df = format_results_dataframe(res, total_combis, gRNA_list_target)

    # Store the result DataFrame
    result_df_dict[target] = result_df

print("\nAll target regions processed.")
# Now result_df_dict contains the DataFrames for each target region.

Processing 2227 target regions...


Overall Target Regions:   9%|▉         | 211/2227 [01:06<05:36,  6.00it/s, Current Target=NR4A3:P1P2]                                                                   

Skipping target 'CGGBP1:P1': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  29%|██▉       | 647/2227 [03:00<08:05,  3.25it/s, Current Target=SKOR2:ENST00000400404.1,ENST00000425639.1]  

Skipping target 'SMARCC1:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  32%|███▏      | 708/2227 [03:17<08:54,  2.84it/s, Current Target=ZNF621:P1P2]                              

Skipping target 'SMARCC2:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  39%|███▉      | 871/2227 [03:58<03:53,  5.81it/s, Current Target=IKZF3:P1P2]                                                                   

Skipping target 'CD29': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  40%|███▉      | 885/2227 [04:01<04:35,  4.87it/s, Current Target=TULP1:P1P2]                                

Skipping target 'GTF2H2:ENST00000274400.5,ENST00000330280.7': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  41%|████      | 916/2227 [04:08<03:46,  5.79it/s, Current Target=ZNF599:P1P2]               

Skipping target 'PSMB1:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  42%|████▏     | 935/2227 [04:14<07:29,  2.87it/s, Current Target=ZNF251:P1P2] 

Skipping target 'SRCAP:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  43%|████▎     | 962/2227 [04:21<06:39,  3.17it/s, Current Target=ZNF320:P1P2]                              

Skipping target 'CBFA2T3:P1': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  46%|████▌     | 1021/2227 [04:37<07:11,  2.80it/s, Current Target=SMARCA2:P1P2]           

Skipping target 'IGHMBP2:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  47%|████▋     | 1047/2227 [04:44<05:25,  3.63it/s, Current Target=ZNF184:P1P2]              

Skipping target 'MSLN:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  48%|████▊     | 1071/2227 [04:49<04:32,  4.25it/s, Current Target=TEF:P2]      

Skipping target 'NGFRAP1': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  50%|█████     | 1115/2227 [04:59<04:15,  4.36it/s, Current Target=ZNF229:P1P2] 

Skipping target 'PHB2:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  65%|██████▌   | 1458/2227 [06:18<02:57,  4.33it/s, Current Target=THYN1:P1P2]                                 

Skipping target 'CXXC1:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  69%|██████▊   | 1526/2227 [06:33<02:33,  4.56it/s, Current Target=ZHX1:P1P2]              

Skipping target 'TSHZ2:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  71%|███████▏  | 1590/2227 [06:47<02:14,  4.72it/s, Current Target=ELK4:P1P2]               

Skipping target 'NCL:P1': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  72%|███████▏  | 1600/2227 [06:49<01:30,  6.91it/s, Current Target=SMAD4:P1P2]  

Skipping target 'GTF2F2:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  72%|███████▏  | 1606/2227 [06:50<02:18,  4.48it/s, Current Target=OR14J1]                  

Skipping target 'WDHD1:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  75%|███████▌  | 1671/2227 [07:03<02:04,  4.47it/s, Current Target=PCGF2:P2]                                 

Skipping target 'CRTC3:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  76%|███████▌  | 1693/2227 [07:07<01:00,  8.82it/s, Current Target=LBX2:P1P2]             

Skipping target 'ZNF20:P1': Needs at least 2 gRNAs, found 1.
Skipping target 'NCOA3:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  78%|███████▊  | 1740/2227 [07:17<01:21,  5.95it/s, Current Target=ZNF71:P1P2]               

Skipping target 'RELB:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  79%|███████▉  | 1766/2227 [07:23<01:04,  7.11it/s, Current Target=ESX1:P1P2]    

Skipping target 'E2F3:P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  80%|████████  | 1782/2227 [07:26<01:03,  6.98it/s, Current Target=ID3:P1P2]                                                                      

Skipping target 'ARID1B:P1': Needs at least 2 gRNAs, found 1.
Skipping target 'ZFYVE20:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'E4F1:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  80%|████████  | 1788/2227 [07:26<01:01,  7.10it/s, Current Target=LHX3:P2]     

Skipping target 'MTA1:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  80%|████████  | 1791/2227 [07:27<01:32,  4.73it/s, Current Target=KLF14:P1P2] 

Skipping target 'MARS:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  81%|████████▏ | 1810/2227 [07:29<00:38, 10.74it/s, Current Target=LEF1:P1]     

Skipping target 'ZNF592:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  83%|████████▎ | 1849/2227 [07:36<00:43,  8.62it/s, Current Target=ZNF850:P1P2]                              

Skipping target 'ZNF207:P1': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  84%|████████▍ | 1866/2227 [07:40<00:54,  6.63it/s, Current Target=MEF2B:P1P2]  

Skipping target 'ADNP:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  85%|████████▍ | 1891/2227 [07:46<00:42,  7.96it/s, Current Target=ZNF625:P1P2]                              

Skipping target 'HIRA:P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  86%|████████▌ | 1912/2227 [07:49<00:27, 11.59it/s, Current Target=ZNF354A:P1P2]            

Skipping target 'GMEB1:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'POU5F1:P1': Needs at least 2 gRNAs, found 1.
Skipping target 'DAXX:P1': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  86%|████████▋ | 1924/2227 [07:50<00:31,  9.53it/s, Current Target=ZNF829:P1P2] 

Skipping target 'CBFB:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  87%|████████▋ | 1942/2227 [07:55<01:06,  4.31it/s, Current Target=ZNF620:P1P2]

Skipping target 'DMAP1:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'PPP1CB:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  88%|████████▊ | 1949/2227 [07:56<00:37,  7.37it/s, Current Target=BCLAF1:P1P2]

Skipping target 'TAF13:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'ERG:P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  88%|████████▊ | 1955/2227 [07:56<00:42,  6.34it/s, Current Target=FOXB2:ENST00000376708.1]

Skipping target 'ZNF709:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  89%|████████▉ | 1978/2227 [08:00<00:26,  9.55it/s, Current Target=MYCBP:P1P2]              

Skipping target 'SON:P1': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  92%|█████████▏| 2044/2227 [08:10<00:34,  5.31it/s, Current Target=SOX10:P1P2]                                

Skipping target 'TFRC': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  92%|█████████▏| 2058/2227 [08:13<00:33,  5.07it/s, Current Target=AHR:P1P2]               

Skipping target 'CBFA2T2:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  93%|█████████▎| 2071/2227 [08:16<00:22,  6.83it/s, Current Target=ZNF347:P1P2]            

Skipping target 'CD55': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  94%|█████████▍| 2099/2227 [08:20<00:11, 10.96it/s, Current Target=TBX6:P1P2]               

Skipping target 'WHSC1:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  95%|█████████▍| 2106/2227 [08:22<00:25,  4.68it/s, Current Target=NEUROG3:P1P2]

Skipping target 'XRCC6:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  95%|█████████▍| 2113/2227 [08:22<00:16,  6.97it/s, Current Target=ZNF157:ENST00000377073.3]

Skipping target 'NFKB1:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'ZNF704:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'ZNF142:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  96%|█████████▌| 2132/2227 [08:26<00:21,  4.38it/s, Current Target=ZNF337:P1P2]             

Skipping target 'ZNF444:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  98%|█████████▊| 2178/2227 [08:32<00:03, 13.02it/s, Current Target=FOXE1:P1P2]               

Skipping target 'PUF60:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'CREM:ENST00000395895.2': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  98%|█████████▊| 2183/2227 [08:32<00:02, 16.75it/s, Current Target=AIRE:P1P2]  

Skipping target 'B2M': Needs at least 2 gRNAs, found 1.


Overall Target Regions:  99%|█████████▊| 2198/2227 [08:33<00:01, 18.92it/s, Current Target=ZNF385A:P1]               

Skipping target 'EBF4:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'MALAT1': Needs at least 2 gRNAs, found 1.
Skipping target 'ZNF395:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'DNAJC21:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'GLI2:ENST00000452319.1': Needs at least 2 gRNAs, found 1.
Skipping target 'KLF16:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'MYRF:P1': Needs at least 2 gRNAs, found 1.
Skipping target 'HDGF:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'SOX18:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'TERF2:P1P2': Needs at least 2 gRNAs, found 1.


Overall Target Regions: 100%|██████████| 2227/2227 [08:33<00:00,  4.34it/s, Current Target=ZNF700:P1P2] 

Skipping target 'PURG:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'OR9Q1': Needs at least 2 gRNAs, found 1.
Skipping target 'CTNNB1:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'PFDN5:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'SACM1L:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'SMARCA5:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'DNMT1:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'CDK1:P1P2': Needs at least 2 gRNAs, found 1.
Skipping target 'ZNF700:P1P2': Needs at least 2 gRNAs, found 1.

All target regions processed.





In [13]:
# Targets that produced a result table, in insertion order.
target_arr = np.array(list(result_df_dict))

In [14]:
def determine_batch_size(total_cell_num, batch_num_basic):
    """
    Scale the base batch number down as the number of cells grows.

    Thresholds (strictly greater than):
        > 20000 cells -> 1
        >  5000 cells -> batch_num_basic // 4
        >   300 cells -> batch_num_basic // 2
        otherwise     -> batch_num_basic

    Args:
        total_cell_num (int): Total number of cells for the current target.
        batch_num_basic (int): Base batch number used for the smallest inputs.

    Returns:
        int: The selected batch number, never less than 1.
    """
    if total_cell_num > 20000:
        return 1
    if total_cell_num > 5000:
        return max(1, batch_num_basic // 4)
    if total_cell_num > 300:
        return max(1, batch_num_basic // 2)
    return max(1, batch_num_basic)

def run_disco_test(X, total_cell_list, device, batch_num, total_permute_disco, target_name):
    """
    Run ``energy_distance_calc.disco_test`` and turn its output into a p-value.

    The test is first run on ``device``. If it raises an error whose message
    mentions "memory" or "cuda", it is retried once on the CPU. The p-value is
    the number of permuted F-values strictly greater than the observed F-value,
    divided by ``total_permute_disco``.

    Args:
        X (array-like): The data matrix (e.g., PCA coordinates).
        total_cell_list (list): One list/array of cell IDs per sgRNA group.
        device: Primary computation device (string or ``torch.device``).
        batch_num (int): Forwarded to ``disco_test`` as ``batch_num``.
        total_permute_disco (int): Denominator of the p-value, i.e. the number
            of permutations the test is expected to have run.
        target_name (str): Name of the target region (used in log messages).

    Returns:
        float: The p-value, or ``np.nan`` when the test could not be completed.
    """
    def _as_numpy(value):
        # Tensors have to be detached and copied to host memory before they can
        # be converted; calling .numpy() directly on a CUDA tensor raises.
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    def _pvalue_on(run_device):
        # Same entry point for the primary attempt and the CPU retry.
        # NOTE(review): the fallback used to call util_functions.disco_test;
        # it now uses the same function as the primary path - confirm.
        obs_fvalue, fvalue_list = energy_distance_calc.disco_test(
            X, total_cell_list, run_device, batch_num=batch_num)
        return np.sum(_as_numpy(fvalue_list) > _as_numpy(obs_fvalue)) / total_permute_disco

    try:
        return _pvalue_on(device)
    except Exception as e_gpu:
        error_str = str(e_gpu).lower()
        if 'memory' not in error_str and 'cuda' not in error_str:
            # Report instead of silently returning NaN so failures stay visible.
            print(f"    GPU execution failed for {target_name} (non-memory error). Skipping disco test. Error: {e_gpu}", flush=True)
            return np.nan
        print(f"    GPU execution failed for {target_name} (likely OOM). Falling back to CPU. Error: {e_gpu}", flush=True)
        try:
            return _pvalue_on("cpu")
        except Exception as e_cpu:
            print(f"    CPU execution also failed for {target_name}. Skipping disco test. Error: {e_cpu}", flush=True)
            return np.nan


def calculate_hypergeometric_pvalue(result_df, gRNA_name, significance_fraction):
    """
    Right-tailed hypergeometric p-value for a single sgRNA.

    The rows of ``result_df`` are treated as the population. The last
    ``max(1, int(rows * significance_fraction))`` rows are treated as the draw
    (the caller is expected to have sorted the frame so that the highest
    e-distances come last). The test asks whether the sgRNA's column is True
    in the draw more often than chance would give.

    Args:
        result_df (pd.DataFrame): Sorted results with one boolean membership
                                  column per sgRNA.
        gRNA_name (str): Column name of the sgRNA being tested.
        significance_fraction (float): Fraction of rows, taken from the end of
                                       the frame, that form the draw.

    Returns:
        float: P(X >= observed) rounded to 4 decimals. 1.0 is returned for an
               empty frame or when the derived parameters are not valid for
               the hypergeometric distribution.
    """
    population_size = result_df.shape[0]
    if population_size == 0:
        # Nothing to rank; callers are expected to filter this case out earlier.
        return 1.0

    membership = result_df[gRNA_name]
    # Rows (combinations) in the whole population that contain this sgRNA.
    successes_in_population = membership.sum()
    # At least one row must be drawn for the test to be defined.
    draw_size = max(1, int(population_size * significance_fraction))
    # The tail of the frame holds the top-ranked combinations.
    successes_in_draw = membership.iloc[-draw_size:].sum()

    parameters_valid = (
        0 <= successes_in_draw <= min(successes_in_population, draw_size)
        and 0 < draw_size <= population_size
        and 0 <= successes_in_population <= population_size
    )
    if not parameters_valid:
        print(f"    Warning: Invalid parameters for hypergeometric test for {gRNA_name}. "
              f"[M={population_size}, n={successes_in_population}, N={draw_size}, x={successes_in_draw}]. Returning p=1.0")
        return 1.0

    # sf(k) is P(X > k), so P(X >= x) is sf(x - 1).
    right_tail = hypergeom.sf(successes_in_draw - 1,
                              population_size,
                              successes_in_population,
                              draw_size)
    return round(right_tail, 4)


In [15]:
# --- Constants ---
# Define constants near the top for clarity and easy modification
# Targets whose disco-test p-value is above this threshold skip the per-sgRNA
# hypergeometric test; all of their sgRNAs are assigned p = 1.0.
DISCO_P_VALUE_THRESHOLD = 0.05
# Fraction of rows, taken from the highest-e-distance end of each result table,
# that forms the draw of the hypergeometric test.
# NOTE(review): this constant was previously annotated "Top 10%", which does not
# match the value 0.5 — confirm whether 0.5 or 0.1 is intended.
HYPERGEOM_SIGNIFICANCE_FRACTION = 0.5 # Top 50% for hypergeometric test

# --- Main Processing Logic ---

# Names this cell expects from earlier cells:
# result_df_dict: Dict mapping target name to results DataFrame
# gRNA_region_dict: Dict mapping target region name to list/array of sgRNAs
# gRNA_dict: Dict mapping sgRNA name to list/array of cells
# pca_df: Data matrix passed to run_disco_test as X (presumably PCA coordinates — confirm)
# device: Computation device handed to run_disco_test
# config: Settings dict; this cell reads config["gRNA_filtering"]["batch_num_basic"],
#         config["gRNA_filtering"]["total_permute_disco"], and the output folder and
#         file name under config["output_file_name_list"]


# Results of this step:
#   p_val_dict     : sgRNA name  -> outlier p-value (1.0 when no test was run for it)
#   disco_val_dict : target name -> disco-test p-value (NaN when skipped or failed)
p_val_dict = {}
disco_val_dict = {}

print("Starting Step 2: Disco Tests and Hypergeometric Outlier Calculation...")

target_keys = list(result_df_dict.keys())
num_targets = len(target_keys)

for target in tqdm(target_keys, total=num_targets):

    # --- 1. Data Preparation ---
    # `target` is always a key of result_df_dict, so only the region lookup can miss.
    if target not in gRNA_region_dict:
        print(f"  Warning: Missing data for target '{target}'. Skipping.")
        continue

    result_df = result_df_dict[target]
    gRNA_list_target = gRNA_region_dict[target]

    if not isinstance(gRNA_list_target, (list, np.ndarray)) or len(gRNA_list_target) == 0:
        print(f"  Warning: Invalid or empty gRNA list for target '{target}'. Skipping.")
        continue

    # One cell collection per sgRNA, in the same order as gRNA_list_target.
    # A target with an unknown sgRNA is skipped and gets no entry in either result dict.
    try:
        total_cell_list = [gRNA_dict[gRNA_name] for gRNA_name in gRNA_list_target]
    except KeyError as e:
        print(f"  Error: sgRNA name {e} not found in gRNA_dict for target '{target}'. Skipping.")
        continue

    total_cell_num = sum(len(cells) for cells in total_cell_list)
    # Number of sgRNAs that actually have cells assigned.
    valid_groups = sum(len(cells) > 0 for cells in total_cell_list)

    # --- 2. Determine Batch Size ---
    batch_num = determine_batch_size(total_cell_num, config["gRNA_filtering"]["batch_num_basic"])

    # --- 3. Run Disco Test ---
    if valid_groups < 2:
        print(f"  Skipping disco test for '{target}': Fewer than 2 sgRNAs have associated cells.")
        disco_pvalue = np.nan
    else:
        # NOTE(review): total_cell_list may still contain empty groups here;
        # presumably run_disco_test tolerates them — confirm.
        disco_pvalue = run_disco_test(pca_df, total_cell_list, device, batch_num,
                                      config["gRNA_filtering"]["total_permute_disco"], target)

    disco_val_dict[target] = disco_pvalue

    # --- 4. Calculate Individual sgRNA p-values (Conditional) ---
    # The per-sgRNA test only runs when the disco test succeeded, was significant,
    # and the target does not have exactly two sgRNAs. Otherwise every sgRNA of
    # the target receives the default p-value 1.0.
    skip_hypergeometric = (
        math.isnan(disco_pvalue)
        or disco_pvalue > DISCO_P_VALUE_THRESHOLD
        or len(gRNA_list_target) == 2
    )
    if skip_hypergeometric:
        p_val_dict.update(dict.fromkeys(gRNA_list_target, 1.0))
        continue

    if result_df.empty:
        print(f"  Warning: result_df for target '{target}' is empty. Cannot calculate hypergeometric p-values. Assigning 1.0.")
        p_val_dict.update(dict.fromkeys(gRNA_list_target, 1.0))
        continue

    # Every sgRNA needs its own membership column in result_df.
    missing_cols = [g for g in gRNA_list_target if g not in result_df.columns]
    if missing_cols:
        print(f"  Error: Missing gRNA columns in result_df for target '{target}': {missing_cols}. Assigning 1.0 p-value.")
        p_val_dict.update(dict.fromkeys(gRNA_list_target, 1.0))
        continue

    for gRNA_name_tmp in gRNA_list_target:
        p_val_dict[gRNA_name_tmp] = calculate_hypergeometric_pvalue(
            result_df, gRNA_name_tmp, HYPERGEOM_SIGNIFICANCE_FRACTION
        )

# --- 5. Final Output ---
print("\nFinished processing all targets.")
print("Aggregating results and saving to CSV...")

# One entry per sgRNA; the Series name is written as the CSV column header.
outlier_df = pd.Series(p_val_dict, name="pval_outlier")

targeting_output_names = config["output_file_name_list"]
output_csv_file = os.path.join(targeting_output_names["OUTPUT_FOLDER"],
                               targeting_output_names["targeting_outlier_table"])

# Saving is best-effort: a failure is reported but does not stop the notebook.
try:
    outlier_df.to_csv(output_csv_file, header=True)
except Exception as e:
    print(f"Error: Failed to save results to '{output_csv_file}'. Error: {e}")
else:
    print(f"Successfully saved outlier p-values to '{output_csv_file}'")

Starting Step 2: Disco Tests and Hypergeometric Outlier Calculation...


100%|██████████| 2155/2155 [18:03<00:00,  1.99it/s]


Finished processing all targets.
Aggregating results and saving to CSV...
Successfully saved outlier p-values to './pipeline_output/targeting_outlier_table.csv'





### Filtering non-targeting gRNAs

In [16]:
# Non-targeting sgRNAs with more than 20 assigned cells, in gRNA_dict order.
non_target_gRNA_list = [
    gRNA_name
    for gRNA_name, assigned_cells in gRNA_dict.items()
    if gRNA_name.startswith("non-targeting") and len(assigned_cells) > 20
]
# Cell collections aligned position-by-position with non_target_gRNA_list.
cell_id_nontarget_list = [gRNA_dict[gRNA_name] for gRNA_name in non_target_gRNA_list]

In [17]:
# Pairwise computation over the non-targeting sgRNA cell groups. The next cell
# unpacks every element of `res` as (p1, p2, val) and uses p1/p2 as positions
# into non_target_gRNA_list.
# NOTE(review): `vardose` is the keyword exactly as this call spells it;
# presumably a verbosity/progress flag of pairwise_torch — confirm against
# energy_distance_calc before renaming.
res = energy_distance_calc.pairwise_torch(pca_df,cell_id_nontarget_list,device,vardose=True)

100%|██████████| 1559/1559 [00:03<00:00, 393.71it/s]


In [18]:
# Build a symmetric sgRNA-by-sgRNA matrix from the (p1, p2, val) triples in `res`.
n_non_target = len(non_target_gRNA_list)
pairwise_list = np.zeros((n_non_target, n_non_target))
for p1, p2, val in res:
    pairwise_list[p1, p2] = val
    pairwise_list[p2, p1] = val

# The frame gets its own copy so later edits to pairwise_list do not leak into it.
df = pd.DataFrame(
    pairwise_list.copy(),
    index=pd.Index(non_target_gRNA_list, name="sgRNA"),
    columns=pd.Index(non_target_gRNA_list, name="sgRNA"),
)
df.name = 'pairwise PCA distances'

# Combine each pairwise value with the two diagonal terms: 2*d(i,j) - d(j,j) - d(i,i).
sigmas = np.diag(df.to_numpy())
target_estats = 2 * df.to_numpy() - sigmas[np.newaxis, :] - sigmas[:, np.newaxis]

# Cluster the sgRNAs by their rows of target_estats (fixed seed for reproducibility).
kmeans = KMeans(n_clusters=3, random_state=0, n_init="auto")
kmeans.fit(target_estats)
label, count = np.unique(kmeans.labels_, return_counts=True)
print("k-means analysis of non-targeting")
print(label,count)
# Original note: this is the non-target list for TF perturb seq.

# The most populated cluster and the fraction of sgRNAs that fall into it.
largest_group_label = label[count.argmax()]
largest_group_ratio = np.round(np.mean(kmeans.labels_ == largest_group_label), 3)

print("largest group: ",largest_group_label)
print("largest group ratio: ",largest_group_ratio)

k-means analysis of non-targeting
[0 1 2] [462   1  95]
largest group:  0
largest group ratio:  0.828


In [19]:
# One row per non-targeting sgRNA: 1 if it belongs to the largest k-means
# cluster, 0 otherwise.
# NOTE(review): the column shares the name "pval_outlier" with the targeting
# table; presumably downstream code reads 1 as "not an outlier" — confirm.
non_target_gRNA_name_df = pd.DataFrame(
    {"pval_outlier": (kmeans.labels_ == largest_group_label).astype(int)},
    index=non_target_gRNA_list,
)

In [20]:
# Write the non-targeting sgRNA flags to the configured output folder.
nt_output_names = config["output_file_name_list"]
nt_output_csv_file = os.path.join(nt_output_names["OUTPUT_FOLDER"],
                                  nt_output_names["non_targeting_outlier_table"])

non_target_gRNA_name_df.to_csv(nt_output_csv_file)