In [1]:
# Core libraries
import gc
import os, psutil, torch
from collections import defaultdict
from tqdm import tqdm
# Scientific and Data libraries
import numpy as np, pandas as pd
# Multiprocessing
import multiprocessing as mp
from time import time
import logging
logging.basicConfig(level=logging.INFO)

In [2]:
def build_cicero_conns_matrix(batch, value_dtype=np.float32):
    """Convert a batch of Cicero connections into upper-triangle COO arrays.

    Every row (Peak1_idx, Peak2_idx, coaccess) is stored under the canonical
    key (min(idx), max(idx)). A connection listed only in the (hi, lo)
    orientation is therefore kept rather than dropped. When the input lists both
    orientations of a pair, both rows map to the same key; a dict built from the
    result collapses them, and the last one wins.

    Args:
        batch: DataFrame with columns 'Peak1_idx', 'Peak2_idx' (integer peak
            indices) and 'coaccess' (numeric).
        value_dtype: dtype of the returned values. Defaults to float32. float16
            (the former hard-coded choice) keeps only ~3 significant digits, e.g.
            0.05 becomes 0.049988, which can push values across a threshold test.

    Returns:
        (upper_indices, upper_values):
            upper_indices: int64 array of shape (2, n) with
                upper_indices[0] <= upper_indices[1].
            upper_values: the matching coaccess values (length n, ``value_dtype``).
    """
    peak_a = batch['Peak1_idx'].to_numpy(dtype=np.int64)
    peak_b = batch['Peak2_idx'].to_numpy(dtype=np.int64)
    upper_values = batch['coaccess'].to_numpy(dtype=value_dtype)

    # Canonicalise orientation instead of masking: a mask would silently discard
    # connections that only appear as (hi, lo).
    upper_indices = np.stack([np.minimum(peak_a, peak_b), np.maximum(peak_a, peak_b)], axis=0)
    return upper_indices, upper_values

def get_binding_peaks_for_motif_pair(motif_name1, motif_name2, cicero_peaks_index, motif_dir=None):
    """Map each motif's binding sites onto Cicero peak indices.

    For each motif, read ``<motif_dir>/<motif>.txt``. This is a tab-separated
    file with a header row containing at least 'chrom', 'start' and 'end'. The
    function then:

    - de-duplicates the coordinates,
    - builds the ``chrom_start_end`` peak id,
    - looks that id up in ``cicero_peaks_index``.

    Sites that are not Cicero peaks are dropped.

    Args:
        motif_name1, motif_name2: motif names, used verbatim as file stems.
        cicero_peaks_index: dict mapping ``chrom_start_end`` peak ids to integer
            peak indices.
        motif_dir: directory holding the per-motif files. Defaults to
            ``{out_path}/out_files/txt_files_for_cicero``, using the module-level
            ``out_path`` (the original behaviour).

    Returns:
        list of two int32 numpy arrays [motif1 indices, motif2 indices], with one
        entry per distinct (chrom, start, end) site that is a Cicero peak. The
        entries keep the file's first-occurrence order.
    """
    if motif_dir is None:
        motif_dir = f"{out_path}/out_files/txt_files_for_cicero"

    motif_arrays = []
    for motif in (motif_name1, motif_name2):
        sites = (
            pd.read_csv(f"{motif_dir}/{motif}.txt", sep='\t', header=0)[['chrom', 'start', 'end']]
            .drop_duplicates()
            .reset_index(drop=True)
        )
        peak_ids = sites['chrom'].astype(str) + "_" + sites['start'].astype(str) + "_" + sites['end'].astype(str)
        # Unknown peak ids map to NaN; only that column can be missing, so drop on it alone.
        peak_index = peak_ids.map(cicero_peaks_index).dropna()
        motif_arrays.append(np.array(peak_index.values, dtype=np.int32))
    return motif_arrays

def motif_peak_pair_intersection_cicero_connections(chunk, partner_idx=None, conn_dict=None):
    """Collect Cicero connections between peaks in ``chunk`` and partner peaks.

    For every ordered pair (a, b), with ``a`` in ``chunk`` and ``b`` in
    ``partner_idx``, the connection is looked up under its canonical key
    (min(a, b), max(a, b)). Every hit is returned as (lo, hi, value). A pair
    whose two peaks appear on both sides is therefore reported once per ordering.

    Args:
        chunk: iterable of peak indices (e.g. a slice of the first motif's peaks).
        partner_idx: peak indices of the second motif. Defaults to the
            module-level ``motif2_peak_idx``, which forked Pool workers inherit.
        conn_dict: mapping (lo, hi) -> coaccess value. Defaults to the
            module-level ``index_value_dict``.

    Returns:
        list of (lo, hi, value) tuples in chunk x partner iteration order.
        lo and hi are plain Python ints.
    """
    if partner_idx is None:
        partner_idx = motif2_peak_idx
    if conn_dict is None:
        conn_dict = index_value_dict

    # Python ints hash/compare much faster than numpy scalars. Keys stored as
    # numpy ints still match, because equal integers hash equally.
    partners = np.asarray(partner_idx).tolist()
    lookup = conn_dict.get
    missing = object()  # sentinel: lets a stored None still count as a hit, as before

    result_chunk = []
    for a in np.asarray(chunk).tolist():
        for b in partners:
            key = (a, b) if a <= b else (b, a)
            value = lookup(key, missing)
            if value is not missing:
                result_chunk.append((key[0], key[1], value))
    return result_chunk

In [6]:
# main
logging.debug("Starting the script")
# Project root. Set CE_SEEK_WD to run elsewhere; the default is the original HPC path.
wd = os.environ.get('CE_SEEK_WD', '/ocean/projects/cis240075p/skeshari/igvf/bcell2/primaryBCell/')
out_path = os.path.join(wd, 'out_data', 'ce_seek')
for sub_dir in ("figures", "out_files", "out_files/txt_files_for_cicero"):
    os.makedirs(f"{out_path}/{sub_dir}", exist_ok=True)
logging.debug("Directories created, starting the script")

# Load the enriched TFs
slide_enriched_tfs = pd.read_csv(os.path.join(wd, 'out_data', 'comb_ctrl/out_files/SLIDE_LF_enriched_TFs.csv'), header=0)
# Directory containing the text files with JASPAR motifs
search_directory = os.path.join(wd, 'out_data', 'out_other_methods', 'JASPAR2024_CORE_vertebrates_non-redundant_homer')
input_file = f"{out_path}/out_files/concatenated_file.txt"

# Motif names (as in JASPAR) are the 2nd tab-separated field of each '>' header line.
with open(input_file, 'r') as infile:
    motif_name_list = [line.strip().split('\t')[1] for line in infile if line.startswith('>')]
logging.debug("Motif names loaded")

# Map the ID inside the parentheses (e.g. 'MA1470.2' in 'BACH2(MA1470.2)') to the full name.
# Names without '(' have no such ID; skip them instead of raising IndexError.
mapping_dict = {
    name.split('(')[1].split(')')[0]: name
    for name in motif_name_list
    if '(' in name
}
slide_enriched_tfs['TF1'] = slide_enriched_tfs['TF1_jid'].map(mapping_dict)
slide_enriched_tfs['TF2'] = slide_enriched_tfs['TF2_jid'].map(mapping_dict)
n_pairs_before = len(slide_enriched_tfs)
# Drop only pairs whose IDs could not be mapped; NaN in any other column must not discard a pair.
slide_enriched_tfs = slide_enriched_tfs.dropna(subset=['TF1', 'TF2']).reset_index(drop=True)
if len(slide_enriched_tfs) < n_pairs_before:
    logging.warning(f"{n_pairs_before - len(slide_enriched_tfs)} TF pairs dropped: motif ID not found in {input_file}")
comb_of_interest = list(zip(slide_enriched_tfs['TF1'], slide_enriched_tfs['TF2']))
logging.info(f"STEP1: Number of combinations of interest are: {len(comb_of_interest)}")
# Build Cicero connections matrix
cicero_conns = pd.read_csv(f"{wd}/out_data/cicero_output/cicero_connections.csv", header=0, index_col=0).dropna()
cicero_conns = cicero_conns[cicero_conns['coaccess'] != 0].reset_index(drop=True)

# Precompute indices: one integer id per distinct peak, assigned in sorted peak-name order.
cicero_peaks_set = pd.concat([cicero_conns['Peak1'], cicero_conns['Peak2']]).unique()
cicero_peaks_index = {peak: idx for idx, peak in enumerate(sorted(cicero_peaks_set))}
cicero_conns['Peak1_idx'] = cicero_conns['Peak1'].map(cicero_peaks_index)
cicero_conns['Peak2_idx'] = cicero_conns['Peak2'].map(cicero_peaks_index)
num_peaks = len(cicero_peaks_index)
logging.info(f"Number of peaks: {num_peaks}")

available_memory = psutil.virtual_memory().available / (1024**3)  # GiB
batch_size = int(1e6 if available_memory > 2 else 1e5)
batches = [cicero_conns.iloc[i:i + batch_size] for i in range(0, len(cicero_conns), batch_size)]
logging.debug("Cicero connections matrix has been preprocessed")

# Process batches
matrix = [
    build_cicero_conns_matrix(batch)
    for batch in tqdm(batches, total=len(batches), desc='Building Cicero connections matrix')
]
# Combine with a single concatenate; concatenating inside a loop re-copies the growing arrays.
if matrix:
    upper_indices = np.concatenate([mtx[0] for mtx in matrix], axis=1)
    upper_values = np.concatenate([mtx[1] for mtx in matrix])
else:  # no non-zero connections: avoid IndexError on matrix[0]
    upper_indices = np.empty((2, 0), dtype=np.int64)
    upper_values = np.empty(0, dtype=np.float16)

# Keys as Python ints (fast hashing); iterating the value array yields the same numpy scalars
# as upper_values[i] did. A repeated (lo, hi) key keeps its last value.
index_value_dict = dict(zip(zip(upper_indices[0].tolist(), upper_indices[1].tolist()), upper_values))
logging.info("STEP2: Cicero connections matrix has been built")

logging.debug('Clearing the memory')
# `batches` holds row slices of cicero_conns (which can share its memory), so release it too.
# Nothing in this cell allocates GPU memory, so torch.cuda.empty_cache() had no effect here.
del cicero_conns, matrix, batches
gc.collect()

# Score every TF pair: the number of ordered (TF1-bound peak, TF2-bound peak) pairs linked
# by a Cicero connection with |coaccess| >= COACCESS_THRESHOLD.
COACCESS_THRESHOLD = 0.05

# Flatten the connection dict into edge arrays once. Each combination is then scored with
# vectorised O(#edges) lookups, instead of |TF1 peaks| x |TF2 peaks| dict probes plus a
# freshly forked process pool per combination (the previous run never finished).
edge_pairs = np.array(list(index_value_dict), dtype=np.int64).reshape(-1, 2)
edge_values = np.array(list(index_value_dict.values()))
is_strong = np.abs(edge_values) >= COACCESS_THRESHOLD
strong_lo = edge_pairs[is_strong, 0]
strong_hi = edge_pairs[is_strong, 1]
strong_not_self = strong_lo != strong_hi
num_nodes = len(cicero_peaks_index)

corr_peak_for_motif_pair = []
for preferred_motif_set_name in tqdm(comb_of_interest, desc='Processing combinations'):
    motif1_peak_idx, motif2_peak_idx = get_binding_peaks_for_motif_pair(*preferred_motif_set_name, cicero_peaks_index)
    union_size = len(np.union1d(motif1_peak_idx, motif2_peak_idx))
    logging.info(f"STEP3: Binding peaks for motifs {preferred_motif_set_name} have been loaded")

    # Boolean membership vectors over all Cicero peaks.
    bound_by_1 = np.zeros(num_nodes, dtype=bool)
    bound_by_1[motif1_peak_idx] = True
    bound_by_2 = np.zeros(num_nodes, dtype=bool)
    bound_by_2[motif2_peak_idx] = True

    # Edge (lo, hi) is hit by the ordered pair (lo, hi) and, unless lo == hi, also by (hi, lo).
    # This equals the former per-pair lookup count, assuming each motif index array has no
    # duplicates (get_binding_peaks_for_motif_pair de-duplicates site coordinates).
    forward_hits = np.count_nonzero(bound_by_1[strong_lo] & bound_by_2[strong_hi])
    reverse_hits = np.count_nonzero(bound_by_1[strong_hi] & bound_by_2[strong_lo] & strong_not_self)
    cnt_thr_pos_idx = int(forward_hits + reverse_hits)
    logging.info(f"STEP4: Pair of peaks found in cicero connections has been counted for {preferred_motif_set_name}")

    # Any normalised score (count / union) can be derived from the saved columns;
    # guard against union_peak_index == 0 when doing so.
    corr_peak_for_motif_pair.append((*preferred_motif_set_name, cnt_thr_pos_idx, union_size))

# Save the results
results_df = pd.DataFrame(corr_peak_for_motif_pair, columns=['TF1', 'TF2', 'cnt_thr_pos_idx', 'union_peak_index'])
results_df.to_csv(f"{out_path}/out_files/correlation_results.csv", index=False)
logging.info("Results have been saved")


INFO:root:STEP1: Number of combinations of interest are: 96
INFO:root:Number of peaks: 137091
Building Cicero connections matrix: 100%|██████████| 9/9 [00:00<00:00, 60.42it/s]
INFO:root:STEP2: Cicero connections matrix has been built
Loading peaks for motifs: 100%|██████████| 2/2 [00:00<00:00, 10.57it/s]
INFO:root:STEP3: Binding peaks for motifs ('BACH2(MA1470.2)', 'IRF4(MA1419.2)') have been loaded
INFO:root:STEP4: Pair of peaks found in cicero connections has been counted for ('BACH2(MA1470.2)', 'IRF4(MA1419.2)')
Loading peaks for motifs: 100%|██████████| 2/2 [00:01<00:00,  1.84it/s]
INFO:root:STEP3: Binding peaks for motifs ('BACH2(MA1470.2)', 'PRDM1(MA0508.4)') have been loaded
INFO:root:STEP4: Pair of peaks found in cicero connections has been counted for ('BACH2(MA1470.2)', 'PRDM1(MA0508.4)')
Loading peaks for motifs: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s]
INFO:root:STEP3: Binding peaks for motifs ('BCL11A(MA2324.1)', 'IRF4(MA1419.2)') have been loaded
INFO:root:STEP4: Pai

KeyboardInterrupt: 