Skip to content

Commit

Permalink
Merge pull request #666 from Sichao25/neighbors
Browse files Browse the repository at this point in the history
Docstring and type hints for the prediction module
  • Loading branch information
Xiaojieqiu committed Feb 26, 2024
2 parents 1a33abb + 2b5ba65 commit 9a8d22d
Show file tree
Hide file tree
Showing 6 changed files with 569 additions and 173 deletions.
26 changes: 21 additions & 5 deletions dynamo/prediction/fate.py
Expand Up @@ -91,8 +91,7 @@ def fate(
kwargs: Additional parameters that will be passed into the fate function.
Returns:
adata: AnnData object that is updated with the dictionary Fate (includes `t` and `prediction` keys) in uns
attribute.
AnnData object that is updated with the dictionary Fate (includes `t` and `prediction` keys) in uns attribute.
"""

if basis is not None:
Expand Down Expand Up @@ -194,8 +193,14 @@ def _fate(
multiprocessing will be used to parallel the fate prediction.
Returns:
t: The time at which the cell state are predicted.
prediction: Predicted cells states at different time points. Row order corresponds to the element order in t. If init_states corresponds to multiple cells, the expression dynamics over time for each cell is concatenated by rows. That is, the final dimension of prediction is (len(t) * n_cells, n_features). n_cells: number of cells; n_features: number of genes or number of low dimensional embeddings. Of note, if the average is set to be True, the average cell state at each time point is calculated for all cells.
A tuple containing two elements:
t: The time at which the cell state are predicted.
prediction: Predicted cells states at different time points. Row order corresponds to the element order in
t. If init_states corresponds to multiple cells, the expression dynamics over time for each cell is
concatenated by rows. That is, the final dimension of prediction is (len(t) * n_cells, n_features).
n_cells: number of cells; n_features: number of genes or number of low dimensional embeddings.
Of note, if the average is set to be True, the average cell state at each time point is calculated for
all cells.
"""

if sampling == "uniform_indices":
Expand Down Expand Up @@ -260,7 +265,18 @@ def _inverse_transform(
basis: str = "umap",
Qkey: str = "PCs",
) -> Tuple[Union[np.ndarray, List[np.ndarray]], np.ndarray]:
"""Inverse transform the low dimensional vector field prediction back to high dimensional space."""
"""Inverse transform the low dimensional vector field prediction back to high dimensional space.
Args:
adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute.
prediction: Predicted cells states at different time points.
basis: The embedding data to use for predicting cell fate.
Qkey: The key of the PCA loading matrix in `.uns`.
Returns:
The predicted cells states at different time points in high dimensional space and the gene names whose gene
expression will be used for predicting cell fate.
"""
if basis == "pca":
if type(prediction) == list:
exprs = [vector_transformation(cur_pred.T, adata.uns[Qkey]) for cur_pred in prediction]
Expand Down

0 comments on commit 9a8d22d

Please sign in to comment.