In [8]:
import argparse
import os
import muon as mu
import pandas as pd
import pyranges as pr
import sys
import numpy as np

import matplotlib.pyplot as plt

from typing import Literal
from sklearn.metrics import (
    roc_auc_score, 
    average_precision_score, 
    roc_curve, 
    precision_recall_curve
)


from pathlib import Path

# Add the parent directory (assuming utils is in ../utils)
# Add the parent directory of the current working directory
sys.path.append(str(Path().resolve().parent))

from utils.benchmarking_metrics import *

GRN_TOOLS = Literal["scenicplus", "celloracle"]
SCORE_COL = "score_rank"
TF2GENE_W_COL = "TF2Gene_weight"

def modality_names(grn_tool:GRN_TOOLS, cell_type_col, sample):
    if grn_tool == "scenicplus":
        if "metacell_0" in sample:
            return "scRNA_counts", "scATAC_counts", f"scRNA_counts:{cell_type_col}"
        else:
            return "scRNA_counts", "scATAC_counts", "scRNA_counts:Metacell_Key"

def get_adata(mudata, grn_tool:GRN_TOOLS):
    if grn_tool == "scenicplus":
        return mudata["scRNA_counts"]
    else:
        raise ValueError(f"Unsupported GRN inference tool: {grn_tool}")


def get_mudata(path:str, grn_tool:GRN_TOOLS):
    if grn_tool == "scenicplus":
        scplus_mdata = mu.read_h5mu(path)
        return scplus_mdata
    else:
        raise ValueError(f"Unsupported GRN inference tool: {grn_tool}")

# GRNs need to have the following columns: "TF", "Gene", "Region"
def preprocess_scenicplus(scplus_mdata):
    # Extract metadata
    direct_df = pd.DataFrame(scplus_mdata.uns['direct_e_regulon_metadata'])
    extended_df = pd.DataFrame(scplus_mdata.uns['extended_e_regulon_metadata'])

    # Combine into one DataFrame
    grn = pd.concat([direct_df, extended_df], ignore_index=True)

    # Filter the relevant columns
    grn_filtered = grn[['Region', 'Gene', 'TF', 'importance_TF2G', 'importance_R2G','regulation', 'rho_TF2G', 'triplet_rank']].copy()

    # Split the 'Region' column into Chromosome, Start, End
    region_split = grn_filtered['Region'].str.extract(r'(chr[\w]+):(\d+)-(\d+)')
    region_split.columns = ['Chromosome', 'Start', 'End']

    # Convert Start and End to integers
    region_split['Start'] = region_split['Start'].astype(int)
    region_split['End'] = region_split['End'].astype(int)

    grn = pd.concat([region_split, grn_filtered], axis=1)

    max_rank = grn["triplet_rank"].max()
    min_rank = grn["triplet_rank"].min()

    grn[SCORE_COL] = (max_rank - grn["triplet_rank"]) / (max_rank - min_rank)

    raw = grn["importance_TF2G"] * grn["rho_TF2G"]
    grn[TF2GENE_W_COL] = np.tanh(3 * raw) 

    return grn

def get_grn_matrix(mudata, grn_tool:GRN_TOOLS):
    if grn_tool == "scenicplus":
        return preprocess_scenicplus(mudata)

def get_benchmark_matrix(tfb_path, prt_path, frc_path, gst_path, tfm_path):
    tfb_matrix = pd.read_csv(tfb_path, sep="\t", header=None)
    tfb_matrix.columns = ["Chromosome", "Start", "End", "TF"]

    prt_matrix = pd.read_csv(prt_path, sep=',')

    frc_matrix = pd.read_csv(frc_path, index_col=0, sep=',')

    gst_matrix = pd.read_csv(gst_path)

    tfm_matrix = df = pd.read_csv(tfm_path, header=None, names=['gene'])

    return tfb_matrix, prt_matrix, frc_matrix, gst_matrix, tfm_matrix





In [2]:
grn_file = "/data/tmpA/andrem/scenicplus/results/AML16_REM_metacell_2/scplusmdata.h5mu"

tfb_goldenf = "/data/benchmarks/andrem/data/pAML/benchmarks/tf-binding-test/dbs/tf_bind_db.bed"
prt_goldenf = "/data/benchmarks/andrem/data/pAML/benchmarks/tf-actvity/knock_tf_benchmark_data.csv"
frc_goldenf = "/data/benchmarks/andrem/data/pAML/benchmarks/tf-actvity/knock_tf_benchmark_data.csv"
gst_goldenf = "/data/benchmarks/andrem/data/pAML/benchmarks/gsets/merged_network.csv"
tfm_goldenf = "/data/benchmarks/andrem/data/pAML/benchmarks/tfm/filtered_tf_markers.txt"
grn_tool = "scenicplus"


In [3]:
mudata = get_mudata(grn_file, grn_tool)
grn_inferred = get_grn_matrix(mudata, grn_tool)
adata = get_adata(mudata, grn_tool)


In [4]:
tfb_golden, prt_golden, frc_golden, gst_matrix, tfm_matrix = get_benchmark_matrix(
    tfb_goldenf, 
    prt_goldenf, 
    frc_goldenf,
    gst_goldenf,
    tfm_goldenf
)


In [92]:
tfb_benchmark = tfb_test(
    grn_inferred=grn_inferred, 
    tf_binding_matrix=tfb_golden, 
    score_column=SCORE_COL
)


In [None]:
# Benchmark PRT
prt_benchmark = prt_test(
    grn_inferred=grn_inferred, 
    prt_matrix=prt_golden, 
    score_column=SCORE_COL,
    weight_column=TF2GENE_W_COL,
    step = 0.01
)


Threshold: 0.00
Threshold: 0.01
Threshold: 0.02
Threshold: 0.03
Threshold: 0.04
Threshold: 0.05
Threshold: 0.06
Threshold: 0.07
Threshold: 0.08
Threshold: 0.09
Threshold: 0.10
Threshold: 0.11
Threshold: 0.12
Threshold: 0.13
Threshold: 0.14
Threshold: 0.15
Threshold: 0.16
Threshold: 0.17
Threshold: 0.18
Threshold: 0.19
Threshold: 0.20
Threshold: 0.21
Threshold: 0.22
Threshold: 0.23
Threshold: 0.24
Threshold: 0.25
Threshold: 0.26
Threshold: 0.27
Threshold: 0.28
Threshold: 0.29
Threshold: 0.30
Threshold: 0.31
Threshold: 0.32
Threshold: 0.33
Threshold: 0.34
Threshold: 0.35
Threshold: 0.36
Threshold: 0.37
Threshold: 0.38
Threshold: 0.39
Threshold: 0.40
Threshold: 0.41
Threshold: 0.42
Threshold: 0.43
Threshold: 0.44
Threshold: 0.45
Threshold: 0.46
Threshold: 0.47
Threshold: 0.48
Threshold: 0.49
Threshold: 0.50
Threshold: 0.51
Threshold: 0.52
Threshold: 0.53
Threshold: 0.54
Threshold: 0.55
Threshold: 0.56
Threshold: 0.57
Threshold: 0.58
Threshold: 0.59
Threshold: 0.60
Threshold: 0.61
Threshol

<module 'utils.benchmarking_metrics' from '/home/andrem/GRN-project/workflow/utils/benchmarking_metrics.py'>

In [90]:
frc_benchmark = frc_test(
    grn_inferred=grn_inferred, 
    adata=adata,
    frc_matrix=frc_golden,
    score_column=SCORE_COL,
    step = 0.05
)

100%|██████████| 132/132 [01:19<00:00,  1.66it/s]


Threshold: 0.00


  8%|▊         | 11/132 [00:14<02:39,  1.32s/it]


KeyboardInterrupt: 

In [9]:
celltype_col = "Classified_Celltype"
sample = "AML16_REM_metacell_2"

rna_mod_name, atac_mod_name, celltype_col = modality_names(grn_tool, celltype_col, sample)
# Omics TF-Gene´
omics_tf2g = omic_test(
    grn_inferred = grn_inferred,
    mdata = mudata,
    score_column = SCORE_COL,
    source_column = "TF",
    target_column = "Gene",
    mod_source = rna_mod_name,
    mod_target= rna_mod_name,
    celltype_column= celltype_col,
    step = 0.3
)

  0%|          | 1/4745 [00:26<35:03:03, 26.60s/it]


KeyboardInterrupt: 

In [10]:
# Omics CRE-Gene
omics_r2g = omic_test(
    grn_inferred = grn_inferred,
    mdata = mudata,
    score_column = SCORE_COL,
    source_column = "Region",
    target_column = "Gene",
    mod_source = atac_mod_name,
    mod_target= rna_mod_name,
    celltype_column= celltype_col,
    step = 0.2
)

  1%|          | 45/4745 [00:03<06:39, 11.77it/s]


KeyboardInterrupt: 

In [11]:
# Omics CRE-TF
omics_r2tf = omic_test(
    grn_inferred = grn_inferred,
    mdata = mudata,
    score_column = SCORE_COL,
    source_column = "Region",
    target_column = "TF",
    mod_source = atac_mod_name,
    mod_target= rna_mod_name,
    celltype_column= celltype_col,
    step = 0.2
)

  0%|          | 0/40 [00:00<?, ?it/s]

 78%|███████▊  | 31/40 [00:03<00:00,  9.73it/s]


KeyboardInterrupt: 

In [52]:
import importlib
import utils.benchmarking_metrics 
from utils.benchmarking_metrics import *

importlib.reload(utils.benchmarking_metrics)

<module 'utils.benchmarking_metrics' from '/home/andrem/GRN-project/workflow/utils/benchmarking_metrics.py'>

In [53]:
gst_benchmark = gst_test(
    grn_inferred=grn_inferred, 
    ptw=gst_matrix, 
    rna=adata,
    score_column=SCORE_COL,
    step=0.01
)

Threshold: 0.01
Precision: 0.8788, Recall: 0.0420, F-beta: 0.7339
Threshold: 0.02
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.03
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.04
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.05
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.06
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.07
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.08
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.09
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.10
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.11
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.12
Precision: 0.8824, Recall: 0.0434, F-beta: 0.7407
Threshold: 0.13


KeyboardInterrupt: 

In [30]:
import importlib
import utils.benchmarking_metrics 
from utils.benchmarking_metrics import *

importlib.reload(utils.benchmarking_metrics)

<module 'utils.benchmarking_metrics' from '/home/andrem/GRN-project/workflow/utils/benchmarking_metrics.py'>

In [63]:
tfm_matrix

Unnamed: 0,gene
0,AEBP1
1,AHR
2,AHRR
3,AKNA
4,ALX1
...,...
846,ZNF143
847,ZNF398
848,ZNF45
849,ZNF692


In [54]:
tfm_benchmark = tfm_test(
    grn_inferred=grn_inferred, 
    db=tfm_matrix, 
    adata=adata,
    score_column=SCORE_COL,
    step=0.01
)


Threshold: 0.01
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.02
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.03
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.04
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.05
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.06
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.07
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.08
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.09
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.10
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.11
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.12
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.13
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.14
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold: 0.15
Precision: 0.7000, Recall: 0.0335, F-beta: 0.5849
Threshold:

In [67]:
tfm_benchmark

{'tp': 77,
 'fp': 33,
 'fn': 758,
 'precision': 0.7,
 'recall': 0.09221556886227544,
 'fbeta': 0.6571187156738486,
 'auroc': nan,
 'auprc': 0.0770899600044084,
 'best_threshold': 0.99,
 'best_fbeta': 0.7281903388608507,
 'best_precision': 0.819672131147541,
 'best_recall': 0.059880239520958084}

In [85]:
num_unique = grn_inferred["score_rank"].round(2).nunique()

In [86]:
num_unique

101

In [87]:
grn_inferred

Unnamed: 0,Chromosome,Start,End,Region,Gene,TF,importance_TF2G,importance_R2G,regulation,rho_TF2G,triplet_rank,score_rank,TF2Gene_weight
0,chr17,55672545,55673045,chr17:55672545-55673045,PCTP,ATF4,1.001423,0.020373,1,0.075926,4753,0.859723,0.224226
1,chr17,75969977,75970477,chr17:75969977-75970477,WBP2,ATF4,1.072571,0.032017,1,0.072392,23537,0.305345,0.228812
2,chr2,152611760,152612260,chr2:152611760-152612260,PRPF40A,ATF4,1.165450,0.061384,1,0.085792,8558,0.747425,0.291275
3,chr1,89956984,89957484,chr1:89956984-89957484,ZNF326,ATF4,1.345857,0.016973,1,0.076129,27887,0.176962,0.298047
4,chr1,43867883,43868383,chr1:43867883-43868383,ATP6V0B,ATF4,1.780124,0.021054,1,0.217779,11875,0.649529,0.822022
...,...,...,...,...,...,...,...,...,...,...,...,...,...
51643,chr18,3279128,3279628,chr18:3279128-3279628,MYL12A,TCF4,0.603821,0.076925,-1,-0.094790,1589,0.953103,-0.170041
51644,chr19,49443334,49443834,chr19:49443334-49443834,NOSIP,TCF4,0.407207,0.012764,-1,-0.072442,5020,0.851843,-0.088266
51645,chr17,63954340,63954840,chr17:63954340-63954840,PSMC5,TCF4,1.355978,0.024230,-1,-0.071852,6881,0.796919,-0.284241
51646,chr11,67444053,67444553,chr11:67444053-67444553,CORO1B,TCF4,1.190088,0.014211,-1,-0.071135,11508,0.660361,-0.248649
