In [1]:
import sys
import os

import pandas as pd
import numpy as np
import scanpy as sc

from sklearn.cluster import KMeans
import scipy.stats
from scipy.stats import hypergeom
from sklearn.metrics import pairwise_distances
from itertools import combinations

from sklearn.manifold import MDS,Isomap,TSNE
from sklearn.cluster import AffinityPropagation,AgglomerativeClustering

import seaborn as sns
import matplotlib.pyplot as plt
# Global matplotlib defaults for every figure in this notebook.
plt.rcParams['axes.labelsize'] = 'large'
# Font type 42 embeds TrueType fonts in saved PDFs (see matplotlib rcParams docs).
plt.rcParams['pdf.fonttype'] = 42
from matplotlib.backends.backend_pdf import PdfPages
from adjustText import adjust_text
import umap

from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

import gc
import warnings
import time
import pickle
import json

from sklearn.metrics import pairwise_distances
from multiprocessing import Pool
import torch

from importlib import reload
import util_functions
import energy_distance_calc

  from .autonotebook import tqdm as notebook_tqdm


<h3>Load data</h3>

In [2]:
# Paths of the two JSON configuration files (relative to the notebook's working directory).
json_fp = "./config.json"
json_fp_cluster = "./config_clustering.json"

# General pipeline configuration (input files, output file names, ...).
with open(json_fp, 'r') as fp:
    config = json.load(fp)

# Clustering-specific configuration (cutoffs, clustering method, ...).
with open(json_fp_cluster, 'r') as fp:
    config_clustering = json.load(fp)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Resolve the four file locations handed to the loader.
input_cfg = config["input_data"]
h5ad_path = input_cfg["h5ad_file"]
sgRNA_path = input_cfg["sgRNA_file"]

output_cfg = config["output_file_name_list"]
pca_table_path = os.path.join(output_cfg["OUTPUT_FOLDER"], output_cfg["pca_table"])
gRNA_dict_path = os.path.join(output_cfg["OUTPUT_FOLDER"], output_cfg["gRNA_dict"])

# overwrite=False: per the recorded output below, existing pickles are loaded
# rather than recomputed.
pca_df, gRNA_dict = util_functions.load_files(
    h5ad_path,
    sgRNA_path,
    pca_table_path,
    gRNA_dict_path,
    overwrite=False,
)

--- Processing PCA data ---
Loading existing PCA file '/project/GCRB/Hon_lab/s223695/Data_project/Perturb_seq_edist_pipeline/pipeline_output/pca_dataframe.pickle'.

--- Processing gRNA dictionary ---
Loading existing gRNA dictionary file '/project/GCRB/Hon_lab/s223695/Data_project/Perturb_seq_edist_pipeline/pipeline_output/gRNA_dictionary.pickle'.
gRNA dictionary loaded successfully. Found 11634 types of gRNAs.

--- Processing finished ---


In [3]:
# Load the per-sgRNA outlier tables (targeting and non-targeting guides),
# using the first CSV column as the index.
output_cfg = config["output_file_name_list"]

targeting_outlier_path = os.path.join(output_cfg["OUTPUT_FOLDER"],
                                      output_cfg["targeting_outlier_table"])
sgRNA_outlier_df = pd.read_csv(targeting_outlier_path, index_col=0)

non_targeting_outlier_path = os.path.join(output_cfg["OUTPUT_FOLDER"],
                                          output_cfg["non_targeting_outlier_table"])
nontargeting_outlier_df = pd.read_csv(non_targeting_outlier_path, index_col=0)

In [4]:
# Keep only the guides whose `pval_outlier` exceeds 0.05.
# NOTE(review): presumably p > 0.05 means "not an outlier" — confirm against the
# step that produced these tables.
targeting_is_clear = sgRNA_outlier_df["pval_outlier"] > 0.05
clear_sgRNA_list = sgRNA_outlier_df.loc[targeting_is_clear].index.tolist()

non_targeting_is_clear = nontargeting_outlier_df["pval_outlier"] > 0.05
clear_nt_sgRNA_list = nontargeting_outlier_df.loc[non_targeting_is_clear].index.tolist()

In [5]:
# Load the sgRNA annotation table (first CSV column becomes the index) and preview it.
output_cfg = config["output_file_name_list"]
annotation_path = os.path.join(output_cfg["OUTPUT_FOLDER"], output_cfg["annotation_file"])
annotation_df = pd.read_csv(annotation_path, index_col=0)
annotation_df.head()

Unnamed: 0,protospacer_ID,target_transcript_name,target_gene_name,source,protospacer,reverse_compliment
0,DNAJC19_ B,DNAJC19,DNAJC19,pos_control,GGGAACTCCTGTAAGGTCAG,CTGACCTTACAGGAGTTCCC
1,POLR1D_ B,POLR1D,POLR1D,pos_control,GGGAAGCAAGGACCGACCGA,TCGGTCGGTCCTTGCTTCCC
2,OR5K2-2,OR5K2,OR5K2,neg_control,GAAAAAATTGTAGAGGAATA,TATTCCTCTACAATTTTTTC
3,SP1_+_53773993.23-P1P2-1,SP1:P1P2,SP1,target,GAAAAACGCGGACGCTGACG,CGTCAGCGTCCGCGTTTTTC
4,SP8_-_20826141.23-P1P2-2,SP8:P1P2,SP8,target,GAAAAAGATCCTCTGAGAGG,CCTCTCAGAGGATCTTTTTC


In [6]:
# Build gRNA_region_dict from the annotation table and the gRNA dictionary.
# NOTE(review): downstream code treats each value as an iterable of gRNA IDs
# keyed by target region — confirm against util_functions.get_gRNA_region_dict.
gRNA_region_dict = util_functions.get_gRNA_region_dict(annotation_df,gRNA_dict)

In [7]:
# Restrict every region to its non-outlier ("clear") gRNAs, dropping regions
# that have none left.
# A set gives O(1) membership tests; the original scanned the full list for
# every gRNA and evaluated the filter twice per region.
clear_sgRNA_set = set(clear_sgRNA_list)

gRNA_region_clear_dict = {}
for key, region_gRNAs in gRNA_region_dict.items():
    gRNA_list_tmp = [x for x in region_gRNAs if x in clear_sgRNA_set]
    if gRNA_list_tmp:
        gRNA_region_clear_dict[key] = gRNA_list_tmp

# For each remaining region, pool the cells of all its clear gRNAs.
# np.unique removes cells counted under more than one gRNA and returns the
# IDs sorted.
cell_per_region_dict = {}
for key, clear_gRNAs in gRNA_region_clear_dict.items():
    cell_list_tmp = np.concatenate([gRNA_dict[i] for i in clear_gRNAs])
    cell_per_region_dict[key] = np.unique(cell_list_tmp)

In [57]:
# Load the per-target energy-distance p-value table (first CSV column is the index).
output_cfg = config["output_file_name_list"]
edist_pvalue_path = os.path.join(output_cfg["OUTPUT_FOLDER"], output_cfg["edist_pvalue_table"])
pval_df = pd.read_csv(edist_pvalue_path, index_col=0)

In [58]:
# Preview the first rows of the energy-distance p-value table.
pval_df.head()

Unnamed: 0,cell_count,source,distance_0,pval_0,distance_1,pval_1,distance_2,pval_2,distance_3,pval_3,...,distance_17,pval_17,distance_18,pval_18,distance_19,pval_19,distance_mean,pval_mean,pval_mean_log,distance_mean_log
DNAJC19,837,pos_control,15.913452,0.0,10.693237,0.0,18.971924,0.0,11.587036,0.0,...,12.598999,0.0,15.845825,0.0,15.46875,0.0,14.111768,0.0,5.0,1.149581
OR5K2,560,neg_control,5.983276,0.019,9.272095,0.004,6.177002,0.018,11.869507,0.0,...,6.688721,0.016,6.685913,0.012,7.246826,0.011,7.710968,0.011,1.958213,0.887109
SP8:P1P2,168,target,40.250488,0.0,53.394409,0.0,40.954102,0.0,54.015259,0.0,...,42.507446,0.0,40.865356,0.0,46.430542,0.0,45.908685,0.0,5.0,1.661895
FOXN3:P2,397,target,3.064941,0.493,3.517212,0.389,3.052734,0.502,3.87854,0.311,...,2.015625,0.875,3.000854,0.543,2.961304,0.54,2.964667,0.5687,0.245109,0.471976
ZNF85:P1P2,495,target,2.021362,0.737,4.43042,0.119,1.848511,0.799,4.47168,0.111,...,2.093018,0.693,1.651611,0.867,2.349365,0.592,2.796436,0.4978,0.302936,0.446605


<h3>Aggregate by target</h3>

In [59]:
# Show the cutoffs read from the clustering configuration file.
config_clustering["cutoff"]

{'pval_cutoff': 0.001, 'distance_cutoff': 0}

In [60]:
# Significant hits: mean p-value strictly below the p-value cutoff AND mean
# energy distance strictly above the distance cutoff.
cutoffs = config_clustering["cutoff"]
passes_pval = pval_df["pval_mean"] < cutoffs["pval_cutoff"]
passes_distance = pval_df["distance_mean"] > cutoffs["distance_cutoff"]
pval_df_sig = pval_df[passes_pval & passes_distance]

In [61]:
# Display the rows of the p-value table that were kept as significant hits.
pval_df_sig

Unnamed: 0,cell_count,source,distance_0,pval_0,distance_1,pval_1,distance_2,pval_2,distance_3,pval_3,...,distance_17,pval_17,distance_18,pval_18,distance_19,pval_19,distance_mean,pval_mean,pval_mean_log,distance_mean_log
DNAJC19,837,pos_control,15.913452,0.000,10.693237,0.000,18.971924,0.000,11.587036,0.000,...,12.598999,0.000,15.845825,0.000,15.468750,0.0,14.111768,0.00000,5.000000,1.149581
SP8:P1P2,168,target,40.250488,0.000,53.394409,0.000,40.954102,0.000,54.015259,0.000,...,42.507446,0.000,40.865356,0.000,46.430542,0.0,45.908685,0.00000,5.000000,1.661895
ZNF616:P1P2,607,target,8.552612,0.004,13.160156,0.000,8.431763,0.001,14.387207,0.000,...,9.401123,0.001,9.552002,0.000,10.612793,0.0,10.505194,0.00085,3.065502,1.021404
HOXB4:P1P2,452,target,70.259888,0.000,91.526611,0.000,71.630493,0.000,92.941650,0.000,...,76.301514,0.000,71.574707,0.000,79.045410,0.0,79.393262,0.00000,5.000000,1.899784
HAND1:P1P2,2268,target,19.077148,0.000,23.584839,0.000,18.418457,0.000,21.271118,0.000,...,19.946167,0.000,16.974487,0.000,19.605225,0.0,20.143616,0.00000,5.000000,1.304137
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZNF292:P1P2,828,target,8.705200,0.002,7.986572,0.001,9.748901,0.000,7.974609,0.001,...,8.506348,0.000,9.165283,0.000,8.748291,0.0,8.323004,0.00055,3.251812,0.920280
ID2:P2,310,target,22.038940,0.000,30.321289,0.000,23.096802,0.000,31.149170,0.000,...,23.799805,0.000,22.365723,0.001,25.813721,0.0,25.808527,0.00005,4.221849,1.411763
TADA2B:P1P2,1654,target,66.343140,0.000,72.220703,0.000,69.895386,0.000,71.029541,0.000,...,63.019287,0.000,67.896729,0.000,68.837402,0.0,69.054883,0.00000,5.000000,1.839194
FOXD3:P1,419,target,34.637329,0.000,26.681152,0.000,32.384644,0.000,29.576904,0.000,...,30.323975,0.000,36.122559,0.000,31.145752,0.0,30.838916,0.00000,5.000000,1.489099


In [62]:
# Unique target names among the significant hits; np.unique also returns them sorted.
region_list_sig = np.unique(pval_df_sig.index)

<h3>Make energy-distance map of significant hits</h3>

In [63]:
# Cell IDs for each significant target, in the same order as region_list_sig.
cell_id_list_target = [cell_per_region_dict[region] for region in region_list_sig]
n_targets = len(cell_id_list_target)
print(n_targets)

67


In [64]:
# Index pairs to evaluate: every unordered pair of distinct targets first,
# followed by each target paired with itself (the matrix diagonal).
target_indices = range(len(cell_id_list_target))
off_diagonal_pairs = list(combinations(target_indices, 2))
diagonal_pairs = [(x, x) for x in target_indices]
combi = off_diagonal_pairs + diagonal_pairs

In [65]:
# Compute the energy distance for every pair of significant targets.
# Each result is stored in `res` as (target1_name, target2_name, distance).
# A pair whose calculation fails on every device is recorded as NaN instead
# of crashing the loop on `None.item()`.
downsampling = config["aggregate"]["downsampling_maximum"]
res = []

# Progress-bar label for the device chosen at load time.
primary_mode = "GPU" if str(device).startswith("cuda") else "CPU"

pbar = tqdm(combi)
for target1_idx, target2_idx in pbar:
    target1_name = region_list_sig[target1_idx]
    target2_name = region_list_sig[target2_idx]

    mode = primary_mode
    pbar.set_postfix({"target1": target1_name,
                      "target2": target2_name,
                      "mode": mode
                     })

    cell_test1 = cell_id_list_target[target1_idx]
    cell_test2 = cell_id_list_target[target2_idx]

    # NOTE(review): this keeps the FIRST `downsampling` cell IDs. Because the
    # per-region arrays come from np.unique (sorted), this is a deterministic
    # head-slice, not a random subsample — confirm that is intended.
    if len(cell_test1) > downsampling:
        cell_test1 = cell_test1[:downsampling]
    if len(cell_test2) > downsampling:
        cell_test2 = cell_test2[:downsampling]

    obs_edist = None

    try:
        # Attempt the calculation on the primary device.
        # NOTE(review): the positional `1, 1` arguments are passed through
        # unchanged; their meaning is defined in energy_distance_calc — confirm.
        obs_edist = energy_distance_calc.permutation_test(
            pca_df, cell_test1, cell_test2,
            device, 1, 1, return_permute=False
        )

    except Exception as e_primary:
        if primary_mode == "GPU":
            print(f"GPU calculation failed: {e_primary}. Attempting CPU fallback...")
            # Release GPU memory before the CPU attempt.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Reflect the fallback in the progress bar.
            mode = "CPU"
            pbar.set_postfix({"target1": target1_name,
                              "target2": target2_name,
                              "mode": mode
                             })

            try:
                obs_edist = energy_distance_calc.permutation_test(
                    pca_df, cell_test1, cell_test2,
                    "cpu", 1, 1, return_permute=False
                )
            except Exception as e_cpu:
                print(f"CPU calculation also failed: {e_cpu}")
        else:
            # Already on the CPU: retrying the same computation cannot help.
            print(f"CPU calculation failed: {e_primary}")

    if obs_edist is None:
        print("Skipping energy distance calculation for this iteration.")
        # NaN marks the pair as missing in the downstream distance matrix.
        edist_value = float("nan")
    else:
        edist_value = obs_edist.item()

    res.append((target1_name, target2_name, edist_value))

100%|██████████| 2278/2278 [00:05<00:00, 391.12it/s, target1=ZSCAN32:P1P2, target2=ZSCAN32:P1P2, mode=GPU]                            


In [66]:
# Assemble the symmetric target-by-target energy-distance matrix from `res`.
# Cells start as NaN rather than the string " ", so a pair that was never
# computed shows up as a missing value and the matrix keeps a numeric dtype
# instead of silently becoming an object-dtype frame.
pairwise_dict = {region: dict.fromkeys(region_list_sig, np.nan)
                 for region in region_list_sig}

for p1, p2, val in tqdm(res):
    # Each pair was computed once; mirror it across the diagonal.
    pairwise_dict[p1][p2] = val
    pairwise_dict[p2][p1] = val
target_estats = pd.DataFrame(pairwise_dict, index=region_list_sig, columns=region_list_sig)

100%|██████████| 2278/2278 [00:00<00:00, 2348727.76it/s]


In [67]:
# Display the target-by-target energy-distance matrix.
target_estats

Unnamed: 0,AHDC1:P1P2,ARID2:P1P2,ARNT:P1P2,ATXN7:ENST00000538065.1,BBX:P1P2,BSX:P1P2,CDX2:P1P2,CDX4:P1P2,DNAJC19,ELMSAN1:P2,...,ZNF439:P1P2,ZNF471:P1P2,ZNF518B:P1P2,ZNF606:P1P2,ZNF616:P1P2,ZNF688:P1P2,ZNF718:ENST00000400172.3,ZNF814:P1P2,ZSCAN20:P1P2,ZSCAN32:P1P2
AHDC1:P1P2,0.000000,56.921875,61.897339,15.753662,17.576538,16.145630,29.487549,15.056152,68.841553,15.847778,...,26.163086,17.691162,23.828125,21.753784,20.659302,17.130005,21.859009,15.590942,22.978149,18.243652
ARID2:P1P2,56.921875,0.000000,78.437622,42.417480,38.004272,48.586060,70.523682,44.999756,33.975586,55.930786,...,28.614990,58.317627,78.738770,34.922241,28.131714,50.265991,23.917114,51.427124,43.406372,56.164551
ARNT:P1P2,61.897339,78.437622,0.000000,36.194702,56.187744,37.771973,53.679077,47.497192,71.083862,38.875488,...,51.997192,43.087280,53.265259,39.937988,50.911621,36.061035,52.268311,50.799805,52.543945,46.248901
ATXN7:ENST00000538065.1,15.753662,42.417480,36.194702,0.000000,17.354858,3.397095,15.813232,3.530518,36.750977,5.611450,...,9.043701,5.382568,11.290039,3.807983,7.689819,3.232300,11.207886,6.673462,8.077515,7.249756
BBX:P1P2,17.576538,38.004272,56.187744,17.354858,0.000000,18.976562,44.100952,18.210327,34.635376,24.779297,...,19.077271,22.362671,37.002563,17.264404,14.414307,16.828125,17.099609,23.378174,23.606445,26.606812
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZNF688:P1P2,17.130005,50.265991,36.061035,3.232300,16.828125,3.761230,15.821655,4.600220,39.110962,5.187500,...,11.799438,3.317993,9.344849,4.340820,10.633057,0.000000,13.090332,6.953125,8.998047,5.871948
ZNF718:ENST00000400172.3,21.859009,23.917114,52.268311,11.207886,17.099609,12.610352,24.509644,10.145386,32.883423,20.493652,...,5.404907,16.116089,32.668335,7.440674,4.321533,13.090332,0.000000,12.807861,9.967529,13.409058
ZNF814:P1P2,15.590942,51.427124,50.799805,6.673462,23.378174,7.283936,13.903198,2.309448,47.040894,7.220703,...,10.676636,3.952759,9.656128,6.878662,10.557617,6.953125,12.807861,0.000000,6.793213,5.605347
ZSCAN20:P1P2,22.978149,43.406372,52.543945,8.077515,23.606445,6.753662,18.079712,5.961548,34.767456,14.157715,...,3.083130,7.702026,17.690308,4.183105,7.368652,8.998047,9.967529,6.793213,0.000000,6.688110


In [68]:
# Save the target-by-target energy-distance matrix to the configured output folder.
output_cfg = config["output_file_name_list"]
edist_matrix_path = os.path.join(output_cfg["OUTPUT_FOLDER"],
                                 output_cfg["edist_target_by_target_matrix"])
target_estats.to_csv(edist_matrix_path)

<h3>Use the best clustering parameters + visualization</h3>

In [69]:
#Embedding using tSNE
fit_method = TSNE(n_components=2, perplexity=2,n_iter=5000,random_state=1,
                  init="random",metric="precomputed")
embedding = fit_method.fit_transform(target_estats.copy())

total_edist_emb=pd.DataFrame(embedding,index=target_estats.index,columns=["x","y"]).reset_index()

In [70]:
# For now, only affinity propagation is supported
if config_clustering["clustering"]["method"]=="Affinity":
    clustering_method_emb = AffinityPropagation(random_state=0,
                                                convergence_iter=15,damping=0.50,
                                                affinity="precomputed"
                                               )
    cluster_info_emb = clustering_method_emb.fit(-target_estats)
else:
    clustering_method_emb = AffinityPropagation(random_state=0,
                                                convergence_iter=15,damping=0.50,
                                                affinity="precomputed"
                                               )
    cluster_info_emb = clustering_method_emb.fit(-target_estats)

total_edist_emb["cluster"] = cluster_info_emb.labels_

In [71]:
# Save the embedding coordinates and cluster labels to the configured output folder.
output_cfg = config["output_file_name_list"]
embedding_info_path = os.path.join(output_cfg["OUTPUT_FOLDER"],
                                   output_cfg["edist_embedding_info"])
total_edist_emb.to_csv(embedding_info_path)