# CellOracle GRN velocity

Run CellOracle to obtain GRN velocities from cell-type-specific networks.

In [None]:
import re
import logging as log
from pathlib import Path
import yaml

import numpy as np
import scipy as sp
import pandas as pd
import scanpy as sc
import scvelo as scv
import celloracle as co

import matplotlib.pyplot as plt
from IPython.display import display, Markdown

In [None]:
# Root logger used by the `log.info(...)` progress messages throughout this
# notebook. The root logger defaults to WARNING, which would silently drop all
# of those INFO messages, so attach a default handler (no-op if the root logger
# already has one) and lower the level explicitly.
logger = log.getLogger()
log.basicConfig()
logger.setLevel(log.INFO)

In [None]:
# Record which CellOracle release produced these results (lazy %-formatting).
log.info("CellOracle version: %s", co.__version__)

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## Params

In [None]:
# --- inputs: CellOracle objects loaded below with `co.load_hdf5` ---
celloracle_obj_path: str = "/path/to/celloracle.oracle"
links_obj_path: str = "/path/to/celloracle.links"

# --- output: Oracle object written after simulation with `oracle.to_hdf5` ---
celloracle_obj_simresult_path: str = "/path/to/simulation_result.oracle"

# --- perturbation setup ---
cell_type_annot: str = "cell_type_obs_column"  # obs column holding cell-type labels
goi: str = "RUNX2"  # gene of interest
perturb_type: str = "KO"  # one of "KO", "overexpression", "topTF"

# --- plotting ---
scale: float = 25  # quiver scale for the per-cell vector plots
scale_sim: float = 0.5  # quiver scale for the grid flow plots
min_mass: float = 0.005  # threshold passed to `calculate_mass_filter`

In [None]:
# Fail fast on an unknown perturbation mode: every simulation cell below is
# gated on `perturb_type == ...`, so a typo would otherwise skip them all
# silently. An explicit raise is used because `assert` is stripped under -O.
if perturb_type not in ("KO", "overexpression", "topTF"):
    raise ValueError(
        f"perturb_type must be one of 'KO', 'overexpression', 'topTF'; got {perturb_type!r}"
    )

## 1) Load

In [None]:
# Progress message for the loading step.
log.info("load celloracle object and links")

In [None]:
# Load the Oracle object; its AnnData (`oracle.adata`) is used throughout the notebook.
oracle = co.load_hdf5(celloracle_obj_path)

In [None]:
# Load the CellOracle Links object; it is filtered and turned into
# cluster-specific TF dictionaries in section 2.
links = co.load_hdf5(links_obj_path)

In [None]:
# Display the Oracle's AnnData summary as the cell output.
oracle.adata

## 2) Fit predictive models for state-specific networks

In [None]:
# Progress message for the model-fitting step.
log.info("fit predictive models")

In [None]:
# Apply CellOracle's link filtering with its default arguments before the
# links are handed to the Oracle in the next cell.
links.filter_links()

In [None]:
# Build per-cluster TF dictionaries from the links; required because the next
# cell fits with `use_cluster_specific_TFdict=True`.
oracle.get_cluster_specific_TFdict_from_Links(links_object=links)

In [None]:
# Fit the cluster-specific GRN models that `simulate_shift` uses below.
# alpha=10 is the regularization strength argument (see CellOracle docs).
oracle.fit_GRN_for_simulation(alpha=10, use_cluster_specific_TFdict=True)

## 3) Simulate gene perturbation

### Plot the force-directed graph embedding (`draw_graph`) and gene expression distribution

In [None]:
# Progress message: the next cell plots the force-directed `draw_graph`
# embedding (not a UMAP, which the previous message wrongly stated).
log.info("plot draw_graph embedding")

In [None]:
# Color the draw_graph embedding by cluster and, when the gene of interest is
# present in the data, by its imputed expression as well.
goi_fields = [goi] if goi in oracle.adata.var_names else []
plot_fields = [oracle.cluster_column_name] + goi_fields

sc.pl.draw_graph(
    oracle.adata,
    color=plot_fields,
    layer="imputed_count",
    use_raw=False,
    cmap="viridis",
)

In [None]:
# Histogram of the gene of interest's imputed expression (best-effort: a
# plotting failure is logged and must not stop the notebook).
# A missing gene is an expected condition (the plot cell above checks the same
# thing), so report it plainly instead of via a caught KeyError traceback.
if goi not in oracle.adata.var_names:
    log.warning("%s not found in oracle.adata.var_names; skipping expression histogram", goi)
else:
    try:
        sc.get.obs_df(oracle.adata, keys=[goi], layer="imputed_count").hist()
        plt.show()
    except Exception:
        log.exception("could not plot gene expression for %s", goi)

### simulate perturbation

In [None]:
if perturb_type == "topTF":
    # Per cell type: rank genes (Wilcoxon, one group vs rest), take the top
    # active regulators and push each one from its mean expression toward its
    # max (positive scaled score) or min (negative), run `simulate_shift`, and
    # keep the result only for that cell type's cells.
    log.info("simulate perturbation per cell type using top TFs")

    sc.tl.rank_genes_groups(oracle.adata, groupby=cell_type_annot, method="wilcoxon")

    # Number of top-ranked active regulators perturbed in each cell type.
    n_top_tfs = 10

    # Per-cell-type results are merged row-wise into these arrays. They start
    # from an unperturbed state (input == imputed expression, zero shift) so
    # cells of a skipped cell type never carry another cell type's simulation.
    # NOTE(review): assumes the simulation layers are dense arrays.
    imputed = oracle.adata.to_df(layer="imputed_count").to_numpy()
    merged_layers = {
        "simulation_input": imputed.copy(),
        "simulated_count": imputed.copy(),
        "delta_X": np.zeros_like(imputed),
    }

    for grp in oracle.adata.obs[cell_type_annot].unique().tolist():
        log.info("simulate shift for %s", grp)

        # Scores of active regulators for this group; presumably in scanpy's
        # descending rank order, which the top-N slice below relies on.
        rank_df = sc.get.rank_genes_groups_df(oracle.adata, group=grp)
        scores = rank_df[rank_df.names.isin(oracle.active_regulatory_genes)].set_index("names")["scores"]

        score_range = scores.max() - scores.min()
        if scores.empty or not score_range > 0:
            # Min-max scaling is undefined here (0/0 would give NaN targets).
            log.warning("no rankable active regulators for %s; leaving its cells unperturbed", grp)
            continue

        # Rescale scores linearly to [-1, 1] across this group's regulators.
        scores = (scores - scores.min()) / score_range * 2 - 1
        top_scores = scores.iloc[:n_top_tfs]

        # Extract expression once, and only for the regulators actually perturbed.
        expr = sc.get.obs_df(oracle.adata, keys=top_scores.index.tolist(), layer="imputed_count")
        mean_vals = expr.mean()
        raised = mean_vals + (expr.max() - mean_vals) * top_scores
        lowered = mean_vals + (mean_vals - expr.min()) * top_scores
        perturb_condition = raised.where(top_scores > 0, lowered).to_dict()

        oracle.simulate_shift(
            perturb_condition=perturb_condition,
            n_propagation=3
        )

        mask = (oracle.adata.obs[cell_type_annot] == grp).to_numpy()
        for layer, merged in merged_layers.items():
            if layer in oracle.adata.layers:
                merged[mask] = oracle.adata.layers[layer][mask]

    # Write back all merged layers, including `simulation_input`, so it agrees
    # with `simulated_count`/`delta_X` instead of holding only the last group's input.
    for layer, merged in merged_layers.items():
        oracle.adata.layers[layer] = merged

In [None]:
if perturb_type == "KO":
    # Knock-out: set the gene of interest to 0 and run `simulate_shift`
    # with n_propagation=3.
    oracle.simulate_shift(perturb_condition={goi: 0.0}, n_propagation=3)

In [None]:
if perturb_type == "overexpression":
    # Overexpression: set the gene of interest to its maximum imputed
    # expression over all cells, then run `simulate_shift`.
    # Select the column by label; positional `Series[0]` on a label-indexed
    # Series is deprecated in pandas 2.x.
    max_val = sc.get.obs_df(oracle.adata, keys=[goi], layer="imputed_count")[goi].max()
    oracle.simulate_shift(
        perturb_condition={goi: max_val},
        n_propagation=3
    )

### transition probabilities and embedding

In [None]:
# Estimate transition probabilities from the simulated shift
# (50 neighbours, knn_random=True, sampled_fraction=1).
oracle.estimate_transition_prob(n_neighbors=50, knn_random=True, sampled_fraction=1)

In [None]:
# Calculate the embedding shift from the transition probabilities above
# (sigma_corr=0.05); used by the quiver/grid plots in section 4.
oracle.calculate_embedding_shift(sigma_corr=0.05)

### save celloracle object

In [None]:
# Persist the Oracle object, including the simulation results, to the output path.
oracle.to_hdf5(celloracle_obj_simresult_path)

## 4) Plot velocities

In [None]:
# Progress message for the velocity plots.
log.info("plot velocities")

In [None]:
# Per-cell quiver plots: simulated shift (left) and the randomized-graph
# comparison (right), both at the `scale` set in the params cell.
fig, ax = plt.subplots(1, 2, figsize=[13, 6])

oracle.plot_quiver(scale=scale, ax=ax[0])
ax[0].set_title(f"Simulated cell identity shift vector: {goi} {perturb_type}")

# Same vectors calculated with a randomized graph.
oracle.plot_quiver_random(scale=scale, ax=ax[1])
ax[1].set_title("Randomized simulation vector")

plt.show()

### plot on a grid

Set up the grid and choose a mass threshold for filtering grid points

In [None]:
# Progress message for the grid setup step.
log.info("setup grid")

In [None]:
# Grid resolution for the flow-on-grid plots.
n_grid = 40
oracle.calculate_p_mass(n_grid=n_grid, smooth=0.8, n_neighbors=50)

In [None]:
# Show 12 candidate mass thresholds to help pick `min_mass` in the params cell.
oracle.suggest_mass_thresholds(n_suggestion=12)

In [None]:
# Apply the grid mass filter with the `min_mass` threshold from the params cell
# and plot the result (the former `min_mass = min_mass` self-assignment was a no-op).
oracle.calculate_mass_filter(min_mass=min_mass, plot=True)

Plot the simulated flow on the grid

In [None]:
# Progress message for the grid flow plots.
log.info("plot velocities on grid")

In [None]:
# Grid flow plots: simulated shift (left) and randomized-graph comparison (right).
fig, ax = plt.subplots(1, 2, figsize=[13, 6])

# Kept as a named alias because later cells read `scale_simulation`.
scale_simulation = scale_sim

oracle.plot_simulation_flow_on_grid(scale=scale_simulation, ax=ax[0])
ax[0].set_title(f"Simulated cell identity shift vector: {goi} {perturb_type}")

# Same flow calculated with a randomized graph.
oracle.plot_simulation_flow_random_on_grid(scale=scale_simulation, ax=ax[1])
ax[1].set_title("Randomized simulation vector")

plt.show()

In [None]:
# Grid vector field overlaid on the cell clusters, saved as a PDF.
fig, ax = plt.subplots(figsize=[8, 8])

oracle.plot_cluster_whole(ax=ax, s=10)
# Use the `scale_sim` parameter directly rather than an alias defined in another cell.
oracle.plot_simulation_flow_on_grid(scale=scale_sim, ax=ax, show_background=False)
ax.set_title(f"Simulated cell identity shift vector: {goi} {perturb_type}")

plot_dir = Path("celloracle_perturbation_plots")
plot_dir.mkdir(parents=False, exist_ok=True)
# Save this figure explicitly instead of relying on pyplot's implicit "current figure".
fig.savefig(plot_dir / f"{goi}_{perturb_type}.pdf", dpi=400)

### plot on PAGA graph

In [None]:
# Best-effort PAGA on the simulated shift: scvelo velocity graph with
# `delta_X` as velocities and `simulation_input` as expression, then PAGA
# grouped by cell type. Failures are logged rather than raised.
try:
    scv.tl.velocity_graph(oracle.adata, xkey="simulation_input", vkey="delta_X")
    scv.tl.paga(oracle.adata, vkey="delta_X", groups=cell_type_annot)
except Exception:
    log.exception("could not calculate velocity graph for PAGA")

In [None]:
# scvelo PAGA comparison plot colored by cell type (best-effort).
try:
    scv.pl.paga_compare(
        oracle.adata,
        color=cell_type_annot,
        transitions="transitions_confidence",
        fontoutline=1.5,
    )
except Exception:
    log.exception("could not plot")

In [None]:
# scvelo PAGA graph with labels on the data (best-effort).
try:
    scv.pl.paga(
        oracle.adata,
        legend_loc="on data",
        edge_width_scale=2.0,
        dashed_edges=None,
        fontoutline=1.5,
    )
except Exception:
    log.exception("could not plot")

In [None]:
# Best-effort stream plot of the simulated shift on the draw_graph embedding,
# saved as SVG.
try:
    # Keep text editable in vector output (TrueType fonts in PDF/PS, real text in SVG).
    plt.rcParams['pdf.fonttype'] = 42
    plt.rcParams['ps.fonttype'] = 42
    plt.rcParams['svg.fonttype'] = "none"

    Path("celloracle_perturbation_plots").mkdir(parents=False, exist_ok=True)

    scv.pl.velocity_embedding_stream(
        oracle.adata,
        vkey="delta_X",
        basis="X_draw_graph_fa",
        # Color by the configured cell-type column (was a hard-coded "annot_v4").
        color=cell_type_annot,
        title=f"Simulated cell identity shift vector: {goi} {perturb_type}",
        linewidth=3,
        alpha=0.1,
        save=f"celloracle_perturbation_plots/{goi}_{perturb_type}.svg",
    )
except Exception:
    log.exception("could not plot embedding stream")