# Run CellOracle

Run CellOracle on SCENIC output (regulons) to obtain cell-type-specific gene regulatory networks.

In [None]:
import re
from contextlib import redirect_stdout
import logging as log
from pathlib import Path
import yaml

import numpy as np
import scipy as sp
import pandas as pd
import scanpy as sc
import celloracle as co

import matplotlib.pyplot as plt
from IPython.display import display, Markdown

In [None]:
# Root logger (the very object logging.getLogger() returns).
logger = log.root

In [None]:
# Record the CellOracle version for reproducibility.
log.log(log.INFO, f"CellOracle version: {co.__version__}")

In [None]:
%matplotlib inline

## Params

input

In [None]:
# Input: metacell-level RNA AnnData file (.h5ad).
metacell_rna_h5ad: str = "/path/to/rna_metacells.h5ad"

output

In [None]:
# Output paths. CellOracle's to_hdf5 writers reject other file suffixes:
# Oracle objects must end in ".celloracle.oracle" and Links objects in
# ".celloracle.links", so the placeholders follow that convention.
celloracle_obj_path = "/path/to/output.celloracle.oracle"
links_obj_path = "/path/to/output.celloracle.links"
links_obj_filtered_path = "/path/to/output_filtered.celloracle.links"

params

In [None]:
# Regulon YAML files; each maps a regulon name to a list of target genes.
regulon_paths: list = ["regulons1.yaml"]
# obs column with the cell-type labels used as CellOracle clusters.
cell_type_annot: str = "cell_type_obs_column"

# Gene selection mode: "HVG" or "full_regulon".
gene_subset: str = "HVG"
# kNN graph construction: "diffmap", "none" or "".
knn_process: str = "diffmap"

In [None]:
# Validate the gene-selection mode. Raise explicitly, because `assert`
# statements are removed when Python runs with -O.
if gene_subset not in ("HVG", "full_regulon"):
    raise ValueError(
        f"gene_subset must be 'HVG' or 'full_regulon', got {gene_subset!r}"
    )

In [None]:
# Validate the kNN mode: "diffmap" builds the kNN graph on a diffusion map;
# "none" and "" build it directly on PCs. Raise explicitly, because `assert`
# statements are removed when Python runs with -O.
if knn_process not in ("diffmap", "none", ""):
    raise ValueError(
        f"knn_process must be 'diffmap', 'none' or '', got {knn_process!r}"
    )

## 1) Load regulons

In [None]:
log.log(log.INFO, "load regulons")

In [None]:
# Merge the regulons from all YAML files into one TF -> target-gene list.
# Each regulon name is cut down to its leading run of characters other than
# "(", ")" and "_" (e.g. "GATA1(+)" -> "GATA1"), so several regulons may map
# to the same TF. Their target lists are merged, whether the collision
# happens within one file or across files. Order is first-seen, without
# duplicates. (Previously, collisions within one file silently kept only the
# last regulon.)
regulons = {}

for p in regulon_paths:
    with open(p, "r", encoding="utf-8") as f:
        log.info(str(p))
        # safe_load: the YAML cannot construct arbitrary Python objects.
        # An empty file yields None, so treat it as "no regulons".
        reg_tmp = yaml.safe_load(f) or {}
    for k, v in reg_tmp.items():
        tf = re.sub(r"([^()_]+).*", r"\1", k)
        targets = regulons.setdefault(tf, [])
        seen = set(targets)
        for gene in v or []:  # a regulon entry with no targets may load as None
            if gene not in seen:
                targets.append(gene)
                seen.add(gene)

In [None]:
# Log a preview: the first five regulons, each with up to five targets.
for tf in list(regulons)[:5]:
    targets = regulons[tf]
    log.info(f"{tf}: {', '.join(targets[:5])}")

## 2) Load RNA anndata

In [None]:
log.log(log.INFO, "load rna anndata object")

In [None]:
# Load the metacell RNA AnnData from disk.
ad = sc.read_h5ad(filename=metacell_rna_h5ad)

**CellOracle needs raw (unnormalised) counts**

In [None]:
# CellOracle needs un-normalised counts. Heuristic: if the largest value in .X
# is below 100, treat .X as normalised. In that case take the counts from .raw,
# which must exist and pass the same check. (Previously, a missing .raw gave
# an AttributeError instead of this ValueError.)
if ad.X.max() < 100:
    if ad.raw is None or ad.raw.X.max() < 100:
        raise ValueError("CellOracle needs raw counts")
    log.info("get raw counts from .raw")
    ad = ad.raw.to_adata()

### subset HVGs

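Restricting the data to highly variable genes (HVGs) is recommended for CellOracle. Alternatively, set `gene_subset = "full_regulon"` to keep every regulon gene present in the data.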

In [None]:
log.log(log.INFO, "select hvg and normalise")

In [None]:
# All genes mentioned by any regulon: the TFs (keys) plus all their targets.
reg_genes = set(regulons).union(*regulons.values())

In [None]:
# Number of distinct regulon genes (TFs + targets). Also written to the log so
# it shows up in non-interactive runs; the bare expression keeps the cell output.
log.info(f"{len(reg_genes)} genes in regulons")
len(reg_genes)

In [None]:
# Drop genes with no counts, then depth-normalise before ranking genes.
sc.pp.filter_genes(ad, min_counts=1)
sc.pp.normalize_total(ad)

if gene_subset == "HVG":
    # Top-2000 genes by dispersion.
    filter_result = sc.pp.filter_genes_dispersion(
        ad.X,
        flavor='cell_ranger',
        n_top_genes=2000,
        log=False,
    )
    # A much wider ranking (up to 20000 genes). It is used only to rescue
    # regulon TFs that did not make the top-2000 cut.
    filter_result2 = sc.pp.filter_genes_dispersion(
        ad.X,
        flavor='cell_ranger',
        n_top_genes=min(20000, ad.X.shape[1]),
        log=False,
    )
    selected = set(ad.var_names[filter_result.gene_subset])
    selected |= set(regulons) & set(ad.var_names[filter_result2.gene_subset])
elif gene_subset == "full_regulon":
    # Every regulon gene (TF or target) that is present in the data.
    selected = set(reg_genes) & set(ad.var_names)
else:
    raise ValueError(
        f"gene_subset must be 'HVG' or 'full_regulon', got {gene_subset!r}"
    )

# Select with a mask over var_names so the genes keep the AnnData order.
# Building the list from a set made the order depend on string hash
# randomisation, so the gene order (and thus PCA etc.) varied between runs.
keep = ad.var_names.isin(selected)
gene_select = ad.var_names[keep].tolist()
log.info(f"selected {len(gene_select)} genes")

# Subset the genes. .copy() materialises the view, so the in-place
# preprocessing below runs on an independent AnnData.
ad = ad[:, keep].copy()

# Keep genes detected in at least 10% of the cells.
sc.pp.filter_genes(ad, min_cells=ad.X.shape[0]*0.1)

# Renormalize after filtering
sc.pp.normalize_per_cell(ad)

### Save (per-cell normalised) counts, log-transform and scale

In [None]:
log.log(log.INFO, "log transform and scale")

In [None]:
# Keep the count data from before log transformation. At this point .X holds
# the per-cell-normalised counts produced by normalize_per_cell above, so,
# despite its name, the "raw_count" layer is NOT un-normalised raw counts.
# .raw snapshots the current genes and values; the layer is a copy of them
# that is put back into .X before the data are handed to CellOracle.
# This must run before log1p/scale, which modify .X in place.
ad.raw = ad
ad.layers["raw_count"] = ad.raw.X.copy()

# Log transformation and scaling
sc.pp.log1p(ad)
sc.pp.scale(ad)

### PCA and knn

In [None]:
log.log(log.INFO, "pca and knn")

In [None]:
# PCA on the log-transformed, scaled matrix.
sc.tl.pca(ad, svd_solver="arpack")

if knn_process != "diffmap":
    # kNN graph built directly on the first 20 PCs.
    sc.pp.neighbors(ad, n_neighbors=30, n_pcs=20)
else:
    # Initial kNN graph on the PCs, needed to compute the diffusion map ...
    sc.pp.neighbors(ad, n_neighbors=10, n_pcs=20)
    sc.tl.diffmap(ad)
    # ... then rebuild the kNN graph on the diffusion components.
    sc.pp.neighbors(ad, n_neighbors=15, use_rep="X_diffmap")

# Leiden clustering on the final kNN graph (default key: obs["leiden"]).
sc.tl.leiden(ad, resolution=1.0)

### PAGA and FA embedding

In [None]:
log.log(log.INFO, "paga and fa embedding")

In [None]:
# PAGA abstraction graph between the annotated cell types.
sc.tl.paga(ad, groups=cell_type_annot)

In [None]:
# Plot the PAGA graph, omitting edges whose weight is below 0.2.
sc.pl.paga(ad, threshold=0.2)

In [None]:
# Compute the force-directed layout seeded from the PAGA positions, unless
# the input object already has an "X_draw_graph_fa" embedding.
if "X_draw_graph_fa" not in ad.obsm.keys():
    sc.tl.draw_graph(ad, init_pos="paga", random_state=123)

In [None]:
# Colour the embedding by whichever of these obs columns actually exist.
_candidates = ("leiden", cell_type_annot, "batch", "n_counts")
plot_cols = [col for col in _candidates if col in ad.obs.columns]

In [None]:
# Show the force-directed embedding coloured by the selected obs columns.
# No figure file is written (a leftover save="test" used to drop a stray
# "*test*" image into scanpy's figure directory); pass save=... to keep one.
sc.pl.draw_graph(ad, color=plot_cols, ncols=1, legend_loc='on data')

## 3) Set up CellOracle object

In [None]:
log.log(log.INFO, "setup celloracle object")

In [None]:
# Empty Oracle; expression data and TF information are imported in later cells.
oracle = co.Oracle()

In [None]:
# Hand CellOracle the saved count layer rather than the log/scaled values:
# restore it into .X before importing.
ad.X = ad.layers["raw_count"].copy()
oracle.import_anndata_as_raw_count(
    ad, cluster_column_name=cell_type_annot, embedding_name="X_draw_graph_fa"
)

In [None]:
# Invert the TF -> targets regulon mapping (presumably into target -> TFs, as
# the variable name says) and register it with the oracle as base GRN info.
TG_to_TF_dictionary = co.utility.inverse_dictionary(regulons)
oracle.import_TF_data(TFdict=TG_to_TF_dictionary)

In [None]:
# Checkpoint: persist the oracle after importing data and TF information.
oracle.to_hdf5(file_path=celloracle_obj_path)

### preprocessing

In [None]:
log.log(log.INFO, "celloracle preprocessing")

**PCA**

In [None]:
# Perform PCA
oracle.perform_PCA()

# Choose the number of PCs. np.diff of the cumulative explained variance
# gives the per-PC variance ratios (from the 2nd PC on). The outer np.diff on
# the boolean "> 0.002" array marks where that condition flips, and the first
# flip is used as the cut-off.
cum_var = np.cumsum(oracle.pca.explained_variance_ratio_)
flips = np.where(np.diff(np.diff(cum_var) > 0.002))[0]
if flips.size:
    n_comps = flips[0]
else:
    # No flip found: indexing [0][0] used to raise IndexError here. Fall back
    # to all available components (still capped at 50 below).
    log.warning("could not detect PCA elbow; using all available components")
    n_comps = len(cum_var)

plt.plot(cum_var[:100])
plt.axvline(n_comps, c="k")
plt.show()
log.info(n_comps)
n_comps = min(n_comps, 50)

In [None]:
log.log(log.INFO, f"chosen pca comp: {n_comps}")

In [None]:
n_cell = oracle.adata.shape[0]
log.info(f"cell number is: {n_cell}")

# Neighbourhood size for kNN imputation: 2.5% of the cells, but at least 1.
# Without the floor, fewer than 40 cells/metacells would give k = 0 (and
# b_sight = b_maxl = 0), which breaks the imputation.
k = max(1, int(0.025*n_cell))
log.info(f"Auto-selected k is: {k}")

**knn-imputation**

In [None]:
# kNN imputation in PCA space with the auto-selected k; the balanced-kNN
# parameters b_sight and b_maxl are derived from k.
oracle.knn_imputation(
    n_pca_dims=n_comps,
    k=k,
    balanced=True,
    b_sight=8 * k,
    b_maxl=4 * k,
    n_jobs=4,
)

In [None]:
# Checkpoint: persist the oracle again after kNN imputation.
oracle.to_hdf5(file_path=celloracle_obj_path)

## 4) GRN fitting

In [None]:
log.log(log.INFO, "grn fitting")

In [None]:
# Fit the cluster-specific GRNs. Anything get_links prints to stdout is
# appended to the snakemake log file instead of the notebook output.
with open(str(snakemake.log.logging), "a") as log_file, redirect_stdout(log_file):
    links = oracle.get_links(
        cluster_name_for_GRN_unit=cell_type_annot,
        alpha=10,
        verbose_level=10,
    )

In [None]:
# Persist the unfiltered links.
links.to_hdf5(links_obj_path)

## 5) Network processing

In [None]:
log.log(log.INFO, "filter links")

In [None]:
# Filter edges: p-value cutoff 0.001, ranked by "coef_abs", at most 2000 kept.
links.filter_links(
    p=0.001,
    weight="coef_abs",
    threshold_number=2000,
)

In [None]:
# Wider default figure size for the network plots below.
plt.rcParams.update({"figure.figsize": [9, 4.5]})

In [None]:
# Degree-distribution plots are best-effort: a failure is logged with its
# traceback and the notebook continues.
try:
    links.plot_degree_distributions(plot_model=True)
except Exception:
    log.exception("could not plot degree distribution")

In [None]:
# Compute network scores (best-effort: a failure is logged, not fatal) and
# show the first rows. `display` is required: an expression inside a try
# block is not rendered as cell output, so `.head()` alone was discarded.
try:
    links.get_network_score()
    display(links.merged_score.head())
except Exception:
    log.exception("could not get network scores")

In [None]:
# Persist the filtered (and, if successful, scored) links.
links.to_hdf5(links_obj_filtered_path)