# CellOracle perturbation scores

Compute perturbation scores by comparing CellOracle vectors to differentiation vectors obtained from another trajectory-based method.

In [None]:
import re
import logging as log
from pathlib import Path
import yaml

import numpy as np
import scipy as sp
import pandas as pd
import scanpy as sc
import celloracle as co
from celloracle.applications import Oracle_development_module

import matplotlib.pyplot as plt
from IPython.display import display, Markdown

In [None]:
# Root logger used for the notebook's progress messages.
# The root logger defaults to WARNING, which silently drops every
# `log.info(...)` call in this notebook, so make sure INFO records pass.
# A more verbose level that was configured beforehand (e.g. DEBUG) is kept.
logger = log.getLogger()
if logger.getEffectiveLevel() > log.INFO:
    logger.setLevel(log.INFO)

In [None]:
# Record the installed CellOracle version in the log for reproducibility
log.info(f"CellOracle version: {co.__version__}")

In [None]:
# IPython magic: use matplotlib's inline backend so figures appear in the cell output
%matplotlib inline

## Params

In [None]:
## input
# Saved CellOracle object; loaded with `co.load_hdf5` and used as the perturbation-simulation input
celloracle_obj_simresult_path = "/path/to/simulation_result.oracle"
# Saved gradient object; loaded with `co.load_hdf5` and used as the differentiation reference
gradient_obj_path = "/path/to/gradient_object.gradient"

## output
# CSV file that receives the per-group mean perturbation scores
perturbation_score_path = "/path/to/perturbation_scores_1.csv"

## params
# Passed as `vm` to the inner-product plots (presumably the colour-scale limit -- confirm)
vm = 0.02
# Passed as `scale` / `scale_for_simulation` when drawing the simulation flow
scale_simulation = 0.5
# Passed as `scale_for_pseudotime` to the layout plot
scale_dev = 40
# Column of `oracle.adata.obs` whose unique values define the groups scored in section 4
cluster_col = "cell_type_obs_column"

## 1) Load

In [None]:
# Progress message emitted before the CellOracle object is read from disk
log.info("load celloracle object")

In [None]:
# Read the object stored at `celloracle_obj_simresult_path`; it is later handed to
# `load_perturb_simulation_data`, so it is expected to already hold simulation results
oracle = co.load_hdf5(celloracle_obj_simresult_path)

In [None]:
# Progress message emitted before the gradient object is read from disk
log.info("load gradient object")

In [None]:
# Read the object stored at `gradient_obj_path`; it is later passed as the
# `gradient_object` reference for the differentiation vectors
gradient = co.load_hdf5(gradient_obj_path)

## 2) Calculate inner product of vectors

In [None]:
# Progress message emitted before the comparison module is set up
log.info("create Oracle_development_module object")

In [None]:
# Create the Oracle_development_module used to compare the two vector fields
# (differentiation reference vs. perturbation simulation) across all cells
dev = Oracle_development_module()

In [None]:
# Load the development flow: register the loaded gradient object as the differentiation reference
dev.load_differentiation_reference_data(gradient_object = gradient)

In [None]:
# Load the simulation result from the oracle object (no cell subset given here,
# unlike the per-group calls in section 4)
dev.load_perturb_simulation_data(oracle_object = oracle)

In [None]:
# Progress message emitted before the inner-product scores are computed
log.info("calculate inner product")

In [None]:
# Calculate inner product scores between the two loaded vector fields,
# then their digitized version using 10 bins
dev.calculate_inner_product()
dev.calculate_digitized_ip(n_bins=10)

## 3) Visualise perturbation scores

### all lineages

In [None]:
# Side-by-side comparison: perturbation scores (left) against the scores
# obtained with a randomized simulation vector (right), drawn with the same `vm`.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[12, 6])
ps_panel, random_panel = ax

dev.plot_inner_product_on_grid(vm=vm, s=50, ax=ps_panel)
# Plain string literals: the titles contain no placeholders, so no f-string is needed
ps_panel.set_title("PS")

dev.plot_inner_product_random_on_grid(vm=vm, s=50, ax=random_panel)
random_panel.set_title("PS calculated with Randomized simulation vector")
plt.show()

In [None]:
# Draw the perturbation-score grid and overlay the perturbation simulation
# vector field on the same axes (background switched off for the overlay)
fig, ax = plt.subplots(figsize=[6, 6])
dev.plot_inner_product_on_grid(vm=vm, s=50, ax=ax)
dev.plot_simulation_flow_on_grid(
    scale=scale_simulation,
    show_background=False,
    ax=ax,
)

In [None]:
# Overview layout for all cells, using the same scaling parameters and `vm`
# as the individual plots above
layout_options = {
    "s": 5,
    "scale_for_simulation": scale_simulation,
    "s_grid": 50,
    "scale_for_pseudotime": scale_dev,
    "vm": vm,
}
dev.visualize_development_module_layout_0(**layout_options)

## 4) Calculate per cell type / lineage

In [None]:
# Mean perturbation score (PS) per group of `cluster_col`, keyed by group label.
# Groups for which a required step fails are logged and left out of the dict.
mean_perturbation_scores = {}

for grp in oracle.adata.obs[cluster_col].unique().tolist():
    # Every step depends on the previous one. When a required step fails the
    # group is skipped, so that a stale `cell_idx` / `dev` left over from an
    # earlier iteration (or from the all-lineage analysis) is never scored
    # and stored under this group's name.
    try:
        # Positional indices of the cells belonging to this group
        cell_idx = np.where(oracle.adata.obs[cluster_col].isin([grp]))[0]
    except Exception:
        log.exception(f"could not select cells for {grp}")
        continue

    try:
        # Fresh module per group
        dev = Oracle_development_module()

        # Load development flow
        dev.load_differentiation_reference_data(gradient_object=gradient)

        # Load simulation result, restricted to this group's cells
        dev.load_perturb_simulation_data(
            oracle_object=oracle,
            cell_idx_use=cell_idx,
            name=grp,
        )
    except Exception:
        log.exception(f"could not load data {grp}")
        continue

    try:
        # Calculation
        dev.calculate_inner_product()
        dev.calculate_digitized_ip(n_bins=10)
    except Exception:
        log.exception(f"could not calculate PS {grp}")
        continue

    try:
        dev.visualize_development_module_layout_0(
            s=5,
            scale_for_simulation=scale_simulation,
            s_grid=50,
            scale_for_pseudotime=scale_dev,
            vm=vm,
        )
    except Exception:
        # Plotting is best-effort: the scores were computed successfully, so
        # the mean is still recorded below even if the figure fails.
        log.exception(f"could not plot {grp}")

    try:
        mean_perturbation_scores[grp] = dev.inner_product_df.score.mean()
    except Exception:
        log.exception(f"could not save PS {grp}")

In [None]:
# Build a one-row frame (row label "PS", one column per group), then flip it
# so that groups become the index and "PS" the single column
scores_by_group = pd.DataFrame(mean_perturbation_scores, index=["PS"])
ps_df = scores_by_group.T

In [None]:
# Bar chart of the mean perturbation score of each group
ps_df.plot(kind="bar", title="perturbation scores per cell type")

## 5) Save

In [None]:
# Create the destination folder if it is missing, so the final write does not
# fail after all scores have been computed, then save the table as CSV.
Path(perturbation_score_path).parent.mkdir(parents=True, exist_ok=True)
ps_df.to_csv(perturbation_score_path)