In [None]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns

In [None]:
import celloracle as co

In [None]:
# Default size (in inches) for matplotlib figures created after this cell.
plt.rcParams["figure.figsize"] = [6,6]
# Ask the inline backend to emit 'retina' figures.
%config InlineBackend.figure_format = 'retina'
# Resolution used when a figure is written to disk with savefig.
plt.rcParams["savefig.dpi"] = 600

# Render matplotlib figures inline in the notebook output.
# NOTE(review): some older IPython/ipykernel versions reset rcParams when the
# inline backend is activated — confirm the figsize above survives this line.
%matplotlib inline

In [None]:
# Output directory for figures exported from this notebook.
# It is created on first run and left untouched if it already exists.
save_folder = "figures_celloracle"
os.makedirs(name=save_folder, exist_ok=True)

In [None]:
# Read the AnnData object from disk and display its summary.
adata = sc.read_h5ad(filename="KO.h5ad")
adata

In [None]:
# Overview of the `X_mde_scanvi` embedding: one panel per obs column, laid out
# in a single column, without axis frames.
sc.pl.embedding(
    adata,
    basis="X_mde_scanvi",
    color=["CellType", "Author", "DevTP"],
    ncols=1,
    frameon=False,
)


In [None]:
# Display the AnnData summary (dimensions and available annotations).
adata

In [None]:
# Sox9 expression on the UMAP embedding; use_raw=False plots values from
# adata.X rather than adata.raw.
sc.pl.umap(adata, color=["Sox9"], use_raw=False)

In [None]:
# Report the dataset dimensions (observations = cells, variables = genes).
print(f"Cell number is: {adata.n_obs}")
print(f"Gene number is: {adata.n_vars}")

In [None]:
# Upper bound on the number of cells carried through the rest of the analysis.
n_cells_downsample = 20_000

# Randomly subsample in place (fixed seed for reproducibility), but only when
# the dataset is larger than the cap.
if adata.n_obs > n_cells_downsample:
    sc.pp.subsample(adata, n_obs=n_cells_downsample, random_state=123)

In [None]:
# Cell count after the (optional) downsampling step.
print(f"Cell number is: {adata.n_obs}")

In [None]:
# Call the top variable genes on a scratch copy, so that the log transform
# below does not modify `adata`.
adata2 = adata.copy()
sc.pp.log1p(adata2)

# Flag the 5000 most variable genes (stored in adata2.var.highly_variable).
# NOTE(review): no per-cell normalisation is applied before log1p here —
# confirm that is intended for the values stored in adata.X.
sc.pp.highly_variable_genes(adata2, n_top_genes=5_000)

In [None]:
# Diagnostic plot of the highly-variable-gene selection computed on adata2.
sc.pl.highly_variable_genes(adata2)

In [None]:
# Keep only the genes flagged as highly variable in adata2.
# `.copy()` turns the subset into an independent AnnData object instead of a
# view, so later writes to it (e.g. adding a layer) do not trigger an implicit
# copy and an ImplicitModificationWarning.
# NOTE(review): confirm the gene that will be perturbed later is among the
# retained genes; it cannot be simulated if it is filtered out here.
adata = adata[:, adata2.var.highly_variable].copy()

In [None]:
# Gene count after restricting to highly variable genes.
print(f"Gene number is: {adata.n_vars}")

For the GRN inference, celloracle needs a base-GRN. There are several ways to make one, the recommended being scATAC-seq data generated within the same experiment. Since we don't have that, we'll use the second option: the base-GRN made from the [sciATAC-seq atlas](http://atlas.gs.washington.edu/mouse-atac/).

In [None]:
# Base GRN (TF info) bundled with celloracle, derived from the mouse
# sci-ATAC-seq atlas.
base_GRN = co.data.load_mouse_scATAC_atlas_base_GRN()

# Preview the first rows.
base_GRN.head(n=5)

### Initiate Oracle object

Oracle is used for the data preprocessing and GRN inference steps. The Oracle Object stores all information and does the calculations with its internal functions. First, we instantiate an Oracle object, then put the gene expression data (anndata) and TF info into the object.

In [None]:
# Instantiate an empty Oracle object; expression data and TF info are loaded
# into it in the following cells.
oracle = co.Oracle()

For the celloracle analysis, the anndata should include (1) gene expression counts, (2) clustering information, and (3) trajectory (dimensional reduction embedding) data.

When you load a scRNA-seq data, please enter **the name of clustering data and dimensional reduction data**. 
- The clustering data should be stored in the `obs` attribute of the anndata. This can be checked with the following command: `adata.obs.columns`.
- The dimensional reduction data is supposed to be stored in the `obsm` attribute of the anndata. This can be checked with the following command: `adata.obsm.keys()`.

In [None]:
# Stash a copy of the current expression matrix under the "raw_count" layer,
# then display the updated AnnData summary.
adata.layers["raw_count"] = adata.X.copy()
adata

In [None]:
# The Oracle object is fed the values saved in the "raw_count" layer, so copy
# them back into adata.X first.
adata.X = adata.layers["raw_count"].copy()

# Load the expression data into the Oracle object, together with the obs column
# holding the cluster labels and the obsm key holding the 2D embedding.
oracle.import_anndata_as_raw_count(
    adata=adata,
    cluster_column_name="CellType",
    embedding_name="X_umap",
)

In [None]:
# Attach the base GRN (TF info dataframe) loaded above to the Oracle object.
oracle.import_TF_data(TF_info_matrix=base_GRN)

### k-NN imputation
Celloracle uses the same strategy as velocyto for visualizing cell transitions. This process requires KNN imputation in advance.
For the KNN imputation, we need to first perform PCA and PC selection.

In [None]:
# Perform PCA
oracle.perform_PCA()

# Cumulative explained variance, computed once and reused below.
cumulative_variance = np.cumsum(oracle.pca.explained_variance_ratio_)
plt.plot(cumulative_variance[:100])

# Select important PCs: indices where the test "variance gained by the next PC
# is > 0.002" flips between True and False. The first flip is used as the elbow.
elbow_candidates = np.where(np.diff(np.diff(cumulative_variance) > 0.002))[0]

# Upper bound on the number of PCs passed to the KNN imputation.
max_comps = 50
if elbow_candidates.size > 0:
    n_comps = int(elbow_candidates[0])
else:
    # No flip was found (e.g. every PC is above or below the threshold).
    # Fall back to the cap instead of failing with an IndexError.
    n_comps = max_comps

# Mark and report the detected elbow, then apply the cap.
plt.axvline(n_comps, c="k")
print(n_comps)
n_comps = min(n_comps, max_comps)

Estimate the optimal number of nearest neighbors for KNN imputation.

In [None]:
# Number of cells held by the Oracle object.
n_cell = oracle.adata.n_obs
print(f"cell number is: {n_cell}")

# Use 2.5% of the cells (truncated to an integer) as the neighbour count.
k = int(n_cell * 0.025)
print(f"Auto-selected k is: {k}")

In [None]:
# KNN imputation using the selected PCs and neighbour count; the b_sight and
# b_maxl settings are expressed as multiples of k.
oracle.knn_imputation(
    n_pca_dims=n_comps,
    k=k,
    balanced=True,
    b_sight=k * 8,
    b_maxl=k * 4,
    n_jobs=4,
)

### GRN calculation

The next step is constructing a cluster-specific GRN for all clusters.

- GRNs are calculated with the `get_links` function, and the function returns GRNs as a `Links` object. The `Links` object stores inferred GRNs and the corresponding metadata. Network analysis can be performed on the `Links` object.

- The GRN will be calculated for each cluster/sub-group.

In [None]:
# Clusters that will each get their own GRN, with the legend in the right margin.
sc.pl.umap(
    adata,
    color=["CellType"],
    legend_loc="right margin",
)

In [None]:
%%time
# Calculate GRN for each population in "louvain_annot" clustering unit.
# This step may take long time.
links = oracle.get_links(cluster_name_for_GRN_unit="CellType", alpha=10,
                         verbose_level=2, test_mode=False)

### in silico TF Perturbation analysis
Next, we will simulate the TF perturbation effects on cell identity to investigate its function and regulatory mechanism. See the celloracle paper for the details and scientific premise on the algorithm.

In this notebook, we’ll simulate knock-out of the Sox9 gene (set as `goi` below).

In [None]:
# Switch the default figure size to 6 x 4.5 inches for the plots that follow.
plt.rcParams["figure.figsize"] = [6, 4.5]

In [None]:
# List the gene names held by the Oracle object's AnnData.
oracle.adata.var.index

In [None]:
# Gene of interest: the gene perturbed in the simulation further down.
goi = "Sox9"

# Imputed expression of the gene of interest next to the cluster labels.
sc.pl.umap(
    oracle.adata,
    color=[goi, oracle.cluster_column_name],
    layer="imputed_count",
    use_raw=False,
    cmap="viridis",
)

In [None]:
# Distribution of the imputed expression of the gene of interest across cells.
(
    sc.get.obs_df(oracle.adata, keys=[goi], layer="imputed_count")
    .hist()
)
plt.show()

In [None]:
# Filter the inferred links (library defaults), in place on the Links object.
links.filter_links()
# Build the cluster-specific TF dictionaries on the Oracle object from the
# filtered links.
oracle.get_cluster_specific_TFdict_from_Links(links_object=links)
# Fit the GRN models used by the simulation, using the cluster-specific TF
# dictionaries created on the previous line (order matters).
oracle.fit_GRN_for_simulation(alpha=10, use_cluster_specific_TFdict=True)

#### Calculate future gene expression after perturbation

Here we simulate Sox9 KO; i.e. we predict what happens if Sox9 gene expression is set to 0.

In [None]:
# Perturbation condition: set the gene of interest to 0.0 (knock-out) and
# propagate the signal through the GRN for 3 iterations.
oracle.simulate_shift(
    perturb_condition={goi: 0.0},
    n_propagation=3,
)

#### Calculate transition probability between cells
- The steps above simulated global future gene expression shift after perturbation. This prediction is based on iterative calculations of signal propagation within the GRN.
- The next step is to calculate the probability of cell state transitions based on the simulation data. You can use the transition probabilities between cells to predict how cells will change after a perturbation.
- This transition probability will be used later.

In [None]:
# Estimate cell-to-cell transition probabilities from the simulated shift.
oracle.estimate_transition_prob(
    n_neighbors=200,
    knn_random=True,
    sampled_fraction=1,
)

# Project the transitions onto the 2D embedding.
oracle.calculate_embedding_shift(sigma_corr=0.05)

### Visualization

#### Quiver plot: show the direction of cell transition at single cell resolution

**Caution: it is very important to find optimal `scale` parameter**
- We need to adjust the `scale` parameter. Please seek to find the optimal `scale` parameter that provides good visualization.

- If you don't see any vector, you can try the smaller scale parameter to magnify vector length. However, if you see large vectors in the right panel, which is a randomized simulation, it means that the scale parameters are too small.

In [None]:
# Two panels: simulation result (left) and randomized-GRN control (right).
fig, ax = plt.subplots(1, 2, figsize=[15, 7])

# The same scale is used in both panels so vector lengths are comparable.
scale = 15

# Left: quiver plot of the simulated perturbation.
oracle.plot_quiver(scale=scale, ax=ax[0])
ax[0].set_title(f"Perturbation simulation results: {goi} KO")

# Right: quiver plot calculated with randomized GRNs (negative control).
oracle.plot_quiver_random(scale=scale, ax=ax[1])
# Plain string: this title has no placeholders, so no f-prefix is needed.
ax[1].set_title("Perturbation simulation with randomized GRNs")

plt.show()

#### Vector field graph

We can visualize simulation result as a vector field graph. Single cell transition vectors are grouped by grid point.

#### Find parameters for n_grid and min_mass

`n_grid`: number of grid points
`min_mass`: threshold value for the cell density. The appropriate values for these parameters depends on the data. Finding appropriate values is done as follows:

In [None]:
# Grid resolution for the vector field; 40 is a reasonable starting point.
n_grid = 40
oracle.calculate_p_mass(
    smooth=0.8,
    n_grid=n_grid,
    n_neighbors=200,
)

Run `oracle.suggest_mass_thresholds()` to find appropriate min_mass parameter. It will give you some examples.

In [None]:
# Show 32 candidate min_mass thresholds to choose from; the chosen value is
# entered by hand in the next cell.
oracle.suggest_mass_thresholds(n_suggestion=32)

In [None]:
# Cell-density threshold picked by eye from the suggestions above; grid points
# are filtered against this value.
min_mass = 190.0
oracle.calculate_mass_filter(
    min_mass=min_mass,
    plot=True,
)

#### Plot vector fields

- Again, we need to adjust the scale parameter. Please seek to find the optimal scale parameter that provides good visualization.

- If you don't see any vector, you can try the smaller scale parameter to magnify vector length. However, if you see large vectors in the right panel, which is a randomized simulation, it means that the scale parameters are too small.

In [None]:
# Two panels: simulation result (left) and randomized-GRN control (right).
fig, ax = plt.subplots(1, 2, figsize=[15, 7])

# The same scale is used in both panels so vector lengths are comparable.
scale_simulation = 8

# Left: grid-level vector field of the simulated perturbation.
oracle.plot_simulation_flow_on_grid(scale=scale_simulation, ax=ax[0])
ax[0].set_title(f"Perturbation simulation results: {goi} KO")

# Right: vector field calculated with randomized GRNs (negative control).
oracle.plot_simulation_flow_random_on_grid(scale=scale_simulation, ax=ax[1])
# Plain string: this title has no placeholders, so no f-prefix is needed.
ax[1].set_title("Perturbation simulation with randomized GRNs")

plt.show()

In [None]:
# Overlay the simulated vector field on the cluster-coloured embedding.
fig, ax = plt.subplots(figsize=[8, 8])

# Draw the cells first, then the vectors on the same axes with the default
# background switched off.
oracle.plot_cluster_whole(ax=ax, s=15)
oracle.plot_simulation_flow_on_grid(
    scale=scale_simulation,
    ax=ax,
    show_background=False,
)

In [None]:
# Reference view for reading the vector fields: imputed expression of the
# perturbed gene next to the cluster labels.
# `goi` is deliberately NOT re-assigned here. It was set before the simulation,
# and re-declaring it in this cell could make the plot show a different gene
# than the one that was actually perturbed.
sc.pl.umap(
    oracle.adata,
    color=[goi, oracle.cluster_column_name],
    layer="imputed_count",
    use_raw=False,
    cmap="viridis",
)

### Compare simulation vector with development vectors

As shown above, we can use celloracle's simulation to infer how TF perturbations affect cell identity. The simulation results are provided in the form of a vector field map.

To interpret the results, it is necessary to take into account the direction of natural differentiation. We will compare the simulated perturbation vectors with the development vector. By comparing them, we can intuitively understand how TF is involved in cell fate determination during development. This perspective is also important for the estimation of experimental perturbation results

Here, we show an example to calculate the vector field of development using pseudotime gradient. In short, the process is as follows.

- Transfer pseudotime data into n x n grid point.

- Calculate the 2D gradient of pseudotime to get vector field

- Compare in silico TF perturbation vector field with development vector field by calculating inner product between these two vectors.

Also, there are many other options to get vector field of development flow from scRNA-seq data, and you can select another option. For example, RNA velocity analysis is a good way to estimate the direction of cell differentiation. Choose the method that best suits the data.

#### Pseudotime data


In the analysis below, we need to use pseudotime data. Pseudotime data is included in the demo data. If you try to analyze your scRNA-seq data, please calculate pseudotime before starting this analysis.

In [None]:
# Pseudotime on the embedding used by the Oracle object.
# Requires a "pseudotime" entry to be available for colouring.
fig, ax = plt.subplots(figsize=[6, 6])

sc.pl.embedding(
    adata=oracle.adata,
    basis=oracle.embedding_name,
    ax=ax,
    cmap="rainbow",
    color=["pseudotime"],
)

#### Make `gradient_calculator` object

In [None]:
from celloracle.applications import Gradient_calculator

# Build the gradient calculator from the Oracle object, reading pseudotime
# from the "pseudotime" key.
gradient = Gradient_calculator(
    oracle_object=oracle,
    pseudotime_key="pseudotime",
)

In [None]:
# Reuse the grid resolution and density threshold chosen for the simulation,
# so that both vector fields are defined on the same grid settings.
gradient.calculate_p_mass(
    smooth=0.8,
    n_grid=n_grid,
    n_neighbors=200,
)
gradient.calculate_mass_filter(
    min_mass=min_mass,
    plot=True,
)

#### Transfer pseudotime values to the grid points

Next we will transfer pseudotime data into grid points. For this calculation we can choose between two methods:
- knn: k-nearest neighbour regressor. You need to set the number of neighbours. Adjust `n_knn` to search for the best results.
`gradient.transfer_data_into_grid(args={"method": "knn", "n_knn":50})`
- polynomial: polynomial regression using the x-axis and y-axis of the dimensionality reduction space. In general this method will be more robust. Use it if k-NN does not work. `n_poly` is the degree of the polynomial regression model; adjust it to find the value that works best for your data.
`gradient.transfer_data_into_grid(args={"method": "polynomial", "n_poly":3})`



In [None]:
# Transfer pseudotime onto the grid with the polynomial method (n_poly=3) and
# plot the result.
gradient.transfer_data_into_grid(args={"method": "polynomial", "n_poly":3}, plot=True)

#### Calculate gradient vectors

Calculate 2D vector map that represents the gradient of pseudotime. After the gradient calculation, the length of the vector will be normalized automatically. Adjust `scale` parameter to adjust vector length.

In [None]:
# Calculate the pseudotime gradient on the grid.
gradient.calculate_gradient()

# Vector scale for the development flow; reused by the later plots.
scale_dev = 20
gradient.visualize_results(
    scale=scale_dev,
    s=5,
)

In [None]:
# Development flow (pseudotime gradient) on the grid, drawn on its own axes.
fig, ax = plt.subplots(figsize=[6, 6])
gradient.plot_dev_flow_on_grid(
    scale=scale_dev,
    ax=ax,
)

### Calculate inner product between two vectors

We will use the inner product to compare the 2D vector map of perturb-simulation and development quantitatively.

If you are not familiar with Inner product / Dot product, please see https://en.wikipedia.org/wiki/Dot_product

- The inner product represents the similarity between two vectors.

- Using the inner product, we compare the 2D vector field of perturbation simulation and development flow.

- Inner product can be a positive value when two vectors are pointing in the same direction.

- Inner product can be a negative value when two vectors are pointing in the opposite direction.

- The length of vector also affects the absolute value of inner product value.

In summary,

- a **negative inner product** means that perturbation might **block differentiation**.
- a **positive inner product** means that perturbation might **promote differentiation**.

In [None]:
from celloracle.applications import Oracle_development_module

# Object used to compare the two vector fields.
dev = Oracle_development_module()

# Input 1: development flow (pseudotime gradient).
dev.load_differentiation_reference_data(gradient_object=gradient)

# Input 2: perturbation simulation result.
dev.load_perturb_simulation_data(oracle_object=oracle)

# Inner product scores between the two fields, then the digitized scores
# using 10 bins.
dev.calculate_inner_product()
dev.calculate_digitized_ip(n_bins=10)

## Show results

In [None]:
# Summary layout of the comparison; the vector scales match the ones chosen
# earlier for the simulation and development-flow plots.
p1 = dev.visualize_development_module_layout_0(
    s=5,
    scale_for_simulation=scale_simulation,
    s_grid=10,
    scale_for_pseudotime=scale_dev,
    vm=0.02,
)

In [None]:
# Inner product score on the grid, with a blue-white-red colormap.
fig, ax = plt.subplots(figsize=[6, 6])

# Pass the axes explicitly. Without `ax`, the call draws on whatever pyplot
# considers the current axes, which only happens to be the figure created
# above.
dev.plot_inner_product_on_grid(ax=ax, s=12, vm=0.01, cmap="bwr")