Skip to content

Commit

Permalink
Merge pull request #514 from Ukyeon/plot_doc
Browse files Browse the repository at this point in the history
Add X_PCA as special key in DKM and remove .obsm.["X"]
  • Loading branch information
Xiaojieqiu committed May 22, 2023
2 parents dfa6a81 + 364e4e6 commit beaf71a
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 25 deletions.
9 changes: 4 additions & 5 deletions dynamo/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class DynamoAdataKeyManager:
# special key names frequently used in dynamo
X_LAYER = "X"
PROTEIN_LAYER = "protein"
X_PCA = "X_pca"

def gen_new_layer_key(layer_name, key, sep="_") -> str:
"""utility function for returning a new key name for a specific layer. By convention layer_name should not have the separator as the last character."""
Expand Down Expand Up @@ -106,7 +107,7 @@ def check_if_layer_exist(adata: AnnData, layer: str) -> bool:
def get_available_layer_keys(adata, layers="all", remove_pp_layers=True, include_protein=True):
"""Get the list of available layers' keys. If `layers` is set to all, return a list of all available layers; if `layers` is set to a list, then the intersetion of available layers and `layers` will be returned."""
layer_keys = list(adata.layers.keys())
if layers is None: # layers=adata.uns["pp"]["experiment_layers"], in calc_sz_factor
if layers is None: # layers=adata.uns["pp"]["experiment_layers"], in calc_sz_factor
layers = "X"
if remove_pp_layers:
layer_keys = [i for i in layer_keys if not i.startswith("X_")]
Expand Down Expand Up @@ -143,10 +144,7 @@ def allowed_X_layer_names():
def init_uns_pp_namespace(adata: AnnData):
adata.uns[DynamoAdataKeyManager.UNS_PP_KEY] = {}

def get_excluded_layers(
X_total_layers: bool = False,
splicing_total_layers: bool = False
) -> List:
def get_excluded_layers(X_total_layers: bool = False, splicing_total_layers: bool = False) -> List:
"""Get a list of excluded layers based on the provided arguments.
When splicing_total_layers is False, the function normalize spliced and unspliced RNA separately using each
Expand Down Expand Up @@ -199,6 +197,7 @@ def aggregate_layers_into_total(
layers.extend(["_total_"])
return total_layers, layers


# TODO discuss alias naming convention
DKM = DynamoAdataKeyManager

Expand Down
5 changes: 3 additions & 2 deletions dynamo/prediction/fate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from ..configuration import DKM
from ..dynamo_logger import (
LoggerManager,
main_info,
Expand Down Expand Up @@ -173,7 +174,7 @@ def fate(
ndim = adata.uns["umap_fit"]["fit"]._raw_data.shape[1]

if "X" in adata.obsm_keys():
if ndim == adata.obsm["X"].shape[1]: # lift the dimension up again
if ndim == adata.obsm[DKM.X_PCA].shape[1]: # lift the dimension up again
exprs = adata.uns["pca_fit"].inverse_transform(prediction)

if adata.var.use_for_dynamics.sum() == exprs.shape[1]:
Expand Down Expand Up @@ -211,7 +212,7 @@ def _fate(
interpolation_num: int = 250,
average: bool = True,
sampling: str = "arc_length",
cores:int = 1,
cores: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
"""Predict the historical and future cell transcriptomic states over arbitrary time scales by integrating vector
field functions from one or a set of initial cell state(s).
Expand Down
3 changes: 2 additions & 1 deletion dynamo/tools/cell_velocities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.decomposition import PCA
from sklearn.utils import sparsefuncs

from ..configuration import DKM
from ..dynamo_logger import LoggerManager, main_info, main_warning
from ..utils import areinstance
from .connectivity import _gen_neighbor_keys, adj_to_knn, check_and_recompute_neighbors
Expand Down Expand Up @@ -530,7 +531,7 @@ def cell_velocities(
adata, pca_fit, X_pca = pca(adata, CM, n_pca_components, "X", return_all=True)
adata.uns["pca_fit"] = pca_fit

X_pca, pca_fit = adata.obsm["X"], adata.uns["pca_fit"]
X_pca, pca_fit = adata.obsm[DKM.X_PCA], adata.uns["pca_fit"]
V = adata[:, adata.var.use_for_dynamics.values].layers[vkey] if vkey in adata.layers.keys() else None
CM, V = CM.A if sp.issparse(CM) else CM, V.A if sp.issparse(V) else V
V[np.isnan(V)] = 0
Expand Down
3 changes: 0 additions & 3 deletions dynamo/tools/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,6 @@ def dynamics(
del adata.layers[i]
main_info("making adata smooth...", indent_level=2)

if filter_gene_mode.lower() == "final" and "X_pca" in adata.obsm.keys():
adata.obsm["X"] = adata.obsm["X_pca"]

if group is not None and group in adata.obs.columns:
moments(adata, genes=valid_bools, group=group)
else:
Expand Down
3 changes: 2 additions & 1 deletion dynamo/tools/metric_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from ..configuration import DKM
from ..dynamo_logger import LoggerManager, main_critical, main_info, main_warning
from .connectivity import (
adj_to_knn,
Expand Down Expand Up @@ -114,7 +115,7 @@ def cell_wise_confidence(
)

n_neigh = n_neigh[0] if type(n_neigh) == np.ndarray else n_neigh
n_pca_components = adata.obsm["X"].shape[1]
n_pca_components = adata.obsm[DKM.X_PCA].shape[1]

finite_inds = get_finite_inds(V, 0)
X, V = X[:, finite_inds], V[:, finite_inds]
Expand Down
26 changes: 14 additions & 12 deletions dynamo/tools/moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,13 @@ def moments(
if X_data is not None:
X = X_data
else:
if "X" not in adata.obsm.keys():
if DKM.X_PCA not in adata.obsm.keys():
if not any([i.startswith("X_") for i in adata.layers.keys()]):
from ..preprocessing.deprecated import recipe_monocle
from ..preprocessing import Preprocessor

genes_to_use = adata.var_names[genes] if genes.dtype == "bool" else genes
recipe_monocle(
adata,
genes_to_use=genes_to_use,
num_dim=n_pca_components,
)
adata.obsm["X"] = adata.obsm["X_pca"]
preprocessor = Preprocessor(force_gene_list=genes_to_use)
preprocessor.preprocess_adata(adata, recipe="monocle")
else:
CM = adata.X if genes is None else adata[:, genes].X
cm_genesums = CM.sum(axis=0)
Expand All @@ -125,7 +121,7 @@ def moments(

adata.uns["explained_variance_ratio_"] = fit.explained_variance_ratio_[1:]

X = adata.obsm["X"][:, :n_pca_components]
X = adata.obsm[DKM.X_PCA][:, :n_pca_components]

with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down Expand Up @@ -187,7 +183,9 @@ def moments(
layer_x = adata.layers[layer].copy()
matched_x_group_indices = np.where([layer in x for x in [only_splicing, only_labeling, splicing_and_labeling]])
if len(matched_x_group_indices[0]) == 0:
logger.warning(f"layer {layer} is not in any of the {only_splicing, only_labeling, splicing_and_labeling} groups, skipping...")
logger.warning(
f"layer {layer} is not in any of the {only_splicing, only_labeling, splicing_and_labeling} groups, skipping..."
)
continue
layer_x_group = matched_x_group_indices[0][0]
layer_x = inverse_norm(adata, layer_x)
Expand All @@ -199,9 +197,13 @@ def moments(
else (conn.dot(layer_x), conn)
)
for layer2 in layers[i:]:
matched_y_group_indices = np.where([layer2 in x for x in [only_splicing, only_labeling, splicing_and_labeling]])
matched_y_group_indices = np.where(
[layer2 in x for x in [only_splicing, only_labeling, splicing_and_labeling]]
)
if len(matched_y_group_indices[0]) == 0:
logger.warning(f"layer {layer2} is not in any of the {only_splicing, only_labeling, splicing_and_labeling} groups, skipping...")
logger.warning(
f"layer {layer2} is not in any of the {only_splicing, only_labeling, splicing_and_labeling} groups, skipping..."
)
continue
layer_y = adata.layers[layer2].copy()

Expand Down
4 changes: 3 additions & 1 deletion dynamo/tools/velocyto_scvelo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import pandas as pd
from scipy.sparse import csr_matrix

from ..configuration import DKM


def vlm_to_adata(
vlm, n_comps: int = 30, basis: str = "umap", trans_mats: Optional[dict] = None, cells_ixs: List[int] = None
Expand Down Expand Up @@ -88,7 +90,7 @@ def vlm_to_adata(

# set obsm
obsm = {}
obsm["X"] = vlm.pcs[:, : min(n_comps, vlm.pcs.shape[1])]
obsm[DKM.X_PCA] = vlm.pcs[:, : min(n_comps, vlm.pcs.shape[1])]
# set basis and velocity on the basis
if basis is not None:
obsm["X_" + basis] = vlm.ts
Expand Down

0 comments on commit beaf71a

Please sign in to comment.