In [1]:
import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import celloracle as co
import mudata as mu
from types import SimpleNamespace


def parse_args():
    return SimpleNamespace(
        input="/data/tmpA/andrem/celloracle/preprocessed_obj/AML12_DX/rna_preprocessed.h5ad",
        base_grn="/data/tmpA/andrem/celloracle/preprocessed_obj/AML12_DX/base_GRN_dataframe.parquet",
        use_atac=True,
        output="/data/tmpA/andrem/celloracle/out/AML12_DX_with_atac/mdata.h5mu",
        downsample=30000
    )

def flatten_links_object(links_obj):
    """Convert CellOracle Links object to a flat DataFrame with cluster labels."""
    records = []
    for cluster_name, df in links_obj.links_dict.items():
        df = df.copy()
        df["cluster"] = cluster_name
        records.append(df)
    return pd.concat(records, ignore_index=True)


def expand_links_with_CREs(links, base_GRN):
    # Prepare base_GRN in long format: one row per (TF, gene)
    tf_columns = [col for col in base_GRN.columns if col not in ["gene_short_name", "chromosome", "start", "end"]]

    melted = base_GRN.melt(
        id_vars=["gene_short_name", "chromosome", "start", "end"],
        value_vars=tf_columns,
        var_name="source",
        value_name="score"
    )
    
    # Filter non-zero scores (active TF-gene interactions)
    melted = melted[melted["score"] != 0]

    # Merge links with base_GRN annotations
    annotated = pd.merge(
        links,
        melted,
        how="left",
        left_on=["source", "target"],
        right_on=["source", "gene_short_name"]
    )

    # Drop extra column
    annotated = annotated.drop(columns=["score", "gene_short_name"])

    # For rows with no CRE info, fill NaNs as before
    for col in ["chromosome", "start", "end"]:
        if col not in annotated.columns:
            annotated[col] = np.nan

    return annotated

In [2]:
args = parse_args()

# Load scRNA-seq data
adata = sc.read_h5ad(args.input)

# Downsample cells if needed
if adata.shape[0] > args.downsample:
    sc.pp.subsample(adata, n_obs=args.downsample, random_state=123)


In [2]:
base_GRN = co.data.load_human_promoter_base_GRN(version="hg38_gimmemotifsv5_fpr2")

Loading prebuilt promoter base-GRN. Version: hg38_gimmemotifsv5_fpr2


In [3]:
base_GRN

Unnamed: 0,peak_id,gene_short_name,9430076C15RIK,AC002126.6,AC012531.1,AC226150.2,AFP,AHR,AHRR,AIRE,...,ZNF784,ZNF8,ZNF816,ZNF85,ZSCAN10,ZSCAN16,ZSCAN22,ZSCAN26,ZSCAN31,ZSCAN4
0,chr10_100009853_100010953,DNMBP,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,chr10_100081785_100082885,CPN1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
2,chr10_100185877_100186977,ERLIN1,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,chr10_100186978_100187057,ERLIN1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,chr10_100229510_100230610,CHUK,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
39310,chrY_9721196_9722296,TTTY21B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
39311,chrY_9735286_9736386,TTTY2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
39312,chrY_9774219_9775319,TTTY1B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
39313,chrY_9800153_9801253,TTTY22,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [3]:

# Load base GRN
if args.use_atac:
    if args.base_grn is None:
        raise ValueError("Please provide --base_grn when using --use_atac")
    print(f"Loading base GRN from {args.base_grn}")
    base_GRN = pd.read_parquet(args.base_grn)
else:
    print("Using default base GRN for human promoter data.")
    base_GRN = co.data.load_human_promoter_base_GRN(version="hg38_gimmemotifsv5_fpr2")

if "peak_id" in base_GRN.columns:
    peak_coords = base_GRN["peak_id"].str.split("_", expand=True)

    # Then validate that 'start' and 'end' columns are numeric
    is_numeric_start = peak_coords[1].str.isnumeric()
    is_numeric_end = peak_coords[2].str.isnumeric()
    valid_rows = is_numeric_start & is_numeric_end

    if not valid_rows.any():
        raise ValueError("No valid 'peak_id' entries with numeric coordinates found.")

    peak_coords = peak_coords[valid_rows]
    base_GRN = base_GRN.loc[valid_rows].copy()
    base_GRN["chromosome"] = peak_coords[0].values
    base_GRN["start"] = peak_coords[1].astype(int).values
    base_GRN["end"] = peak_coords[2].astype(int).values
else:
    raise ValueError("Expected 'peak_id' column in base_GRN, but it was not found.")



Loading base GRN from /data/tmpA/andrem/celloracle/preprocessed_obj/AML12_DX/base_GRN_dataframe.parquet


In [5]:
# Instantiate Oracle
oracle = co.Oracle()
adata.X = adata.layers["raw_count"].copy()
print("Metadata columns :", list(adata.obs.columns))
print("Dimensional reduction: ", list(adata.obsm.keys()))

oracle.import_anndata_as_raw_count(
    adata=adata,
    cluster_column_name="louvain",
    embedding_name="X_draw_graph_fa"
)

oracle.import_TF_data(TF_info_matrix=base_GRN)

oracle.perform_PCA()


Metadata columns : ['patient', 'dx', 'sample', 'batch', 'GEO_ID', 'Lambo_et_al_ID', 'Patient_Sample', 'Library_ID', 'Counts', 'Features', 'Mitochondria_percent', 'Classified_Celltype', 'Seurat_Cluster', 'Malignant', 'Patient_ID', 'Biopsy_Origin', 'Age_Months', 'Disease_free_days', 'Clinical_Blast_Percent', 'Expected_Driving_Aberration', 'Subgroup', 'Color_Subgroup', 'Known_CNVs', 'Treatment_Outcome', 'nCount_RNA', 'nFeature_RNA', 'doublet', 'doublet_score', 'n_counts_all', 'n_counts', 'louvain']
Dimensional reduction:  ['X_diffmap', 'X_draw_graph_fa', 'X_pca']


In [6]:

# Estimate number of PCs
plt.plot(np.cumsum(oracle.pca.explained_variance_ratio_)[:100])
diffs = np.diff(np.diff(np.cumsum(oracle.pca.explained_variance_ratio_)) > 0.002)
n_comps = np.where(diffs)[0][0] if np.any(diffs) else 30
n_comps = min(n_comps, 50)

n_cell = oracle.adata.shape[0]
print(f"Cell number: {n_cell}")
k = int(0.025 * n_cell)
print(f"Auto-selected k: {k}")


Cell number: 6826
Auto-selected k: 170


In [7]:

oracle.knn_imputation(
    n_pca_dims=n_comps,
    k=k,
    balanced=True,
    b_sight=k * 8,
    b_maxl=k * 4,
    n_jobs=4
)


In [8]:

print("Inferring GRNs...")
raw_links = oracle.get_links(
    cluster_name_for_GRN_unit="louvain",
    alpha=10,
    verbose_level=10
)


Inferring GRNs...


  0%|          | 0/32 [00:00<?, ?it/s]

Inferring GRN for 0...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 1...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 10...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 11...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 12...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 13...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 14...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 15...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 16...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 17...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 18...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 19...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 2...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 20...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 21...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 22...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 23...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 24...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 25...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 26...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 27...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 28...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 29...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 3...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 30...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 31...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 4...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 5...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 6...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 7...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 8...


  0%|          | 0/1328 [00:00<?, ?it/s]

Inferring GRN for 9...


  0%|          | 0/1328 [00:00<?, ?it/s]

In [12]:

print(f"Saving links to {args.output}")
link_output_path = Path(args.output).parent / "links.celloracle.links"
# Save Links object.
raw_links.to_hdf5(file_path=str(link_output_path))

raw_links.to_hdf5(file_path=link_output_path)


Saving links to /data/tmpA/andrem/celloracle/out/AML12_DX/without_atac.h5mu


FileNotFoundError: [Errno 2] Unable to synchronously create file (unable to open file: name = '/data/tmpA/andrem/celloracle/out/AML12_DX/links.celloracle.links', errno = 2, error message = 'No such file or directory', flags = 13, o_flags = 242)

In [19]:

print("Flattening GRN links...")
links = flatten_links_object(raw_links)


Flattening GRN links...


In [25]:

print("Annotating links with CREs...")
links = expand_links_with_CREs(links, base_GRN)


Annotating links with CREs...


In [15]:

print("Creating MuData object...")
modalities = {"rna": adata}
if args.use_atac:
    atac_mod = sc.AnnData(X=np.empty((adata.n_obs, 0)))  # Placeholder
    modalities["atac"] = atac_mod

mdata = mu.MuData(modalities)
mdata.uns["celloracle_links"] = links

output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)


Creating MuData object...


In [16]:

print(f"Saving MuData to {args.output}")
mdata.write(args.output)


Saving MuData to /data/tmpA/andrem/celloracle/out/AML12_DX/without_atac.h5mu


In [26]:
links

Unnamed: 0,source,target,coef_mean,coef_abs,p,-logp,cluster,chromosome,start,end
0,FOSL2,A2M-AS1,1.308608e-03,1.308608e-03,8.212768e-09,8.085510,0,chr12,9064970,9066070
1,JUNB,A2M-AS1,-3.645662e-04,3.645662e-04,1.979343e-02,1.703479,0,chr12,9064970,9066070
2,KLF10,A2M-AS1,-6.443677e-04,6.443677e-04,3.456784e-05,4.461328,0,chr12,9064970,9066070
3,FOSB,A2M-AS1,-7.718563e-07,7.718563e-07,9.948294e-01,0.002251,0,chr12,9064970,9066070
4,FOS,A2M-AS1,9.604866e-04,9.604866e-04,2.387355e-08,7.622083,0,chr12,9064970,9066070
...,...,...,...,...,...,...,...,...,...,...
1377947,EGR1,ZWINT,0.000000e+00,0.000000e+00,,-0.000000,9,chr10,56361173,56362273
1377948,ARID5A,ZWINT,0.000000e+00,0.000000e+00,,-0.000000,9,chr10,56361173,56362273
1377949,FOXS1,ZWINT,0.000000e+00,0.000000e+00,,-0.000000,9,chr10,56361173,56362273
1377950,MYC,ZWINT,0.000000e+00,0.000000e+00,,-0.000000,9,chr10,56361173,56362273


In [24]:
mdata