Skip to content

Commit

Permalink
update docstr and typing in prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
Sichao25 committed Feb 21, 2024
1 parent 8d8917e commit 2b5ba65
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 142 deletions.
26 changes: 21 additions & 5 deletions dynamo/prediction/fate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def fate(
kwargs: Additional parameters that will be passed into the fate function.
Returns:
adata: AnnData object that is updated with the dictionary Fate (includes `t` and `prediction` keys) in uns
attribute.
AnnData object that is updated with the dictionary Fate (includes `t` and `prediction` keys) in uns attribute.
"""

if basis is not None:
Expand Down Expand Up @@ -194,8 +193,14 @@ def _fate(
multiprocessing will be used to parallel the fate prediction.
Returns:
t: The time at which the cell state are predicted.
prediction: Predicted cells states at different time points. Row order corresponds to the element order in t. If init_states corresponds to multiple cells, the expression dynamics over time for each cell is concatenated by rows. That is, the final dimension of prediction is (len(t) * n_cells, n_features). n_cells: number of cells; n_features: number of genes or number of low dimensional embeddings. Of note, if the average is set to be True, the average cell state at each time point is calculated for all cells.
A tuple containing two elements:
t: The time points at which the cell states are predicted.
prediction: Predicted cell states at different time points. Row order corresponds to the element order in
t. If init_states corresponds to multiple cells, the expression dynamics over time for each cell is
concatenated by rows. That is, the final dimension of prediction is (len(t) * n_cells, n_features).
n_cells: number of cells; n_features: number of genes or number of low dimensional embeddings.
Of note, if the average is set to be True, the average cell state at each time point is calculated for
all cells.
"""

if sampling == "uniform_indices":
Expand Down Expand Up @@ -260,7 +265,18 @@ def _inverse_transform(
basis: str = "umap",
Qkey: str = "PCs",
) -> Tuple[Union[np.ndarray, List[np.ndarray]], np.ndarray]:
"""Inverse transform the low dimensional vector field prediction back to high dimensional space."""
"""Inverse transform the low dimensional vector field prediction back to high dimensional space.
Args:
adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute.
prediction: Predicted cell states at different time points.
basis: The embedding data to use for predicting cell fate.
Qkey: The key of the PCA loading matrix in `.uns`.
Returns:
The predicted cell states at different time points in high dimensional space, and the names of the genes whose
expression will be used for predicting cell fate.
"""
if basis == "pca":
if type(prediction) == list:
exprs = [vector_transformation(cur_pred.T, adata.uns[Qkey]) for cur_pred in prediction]
Expand Down
111 changes: 62 additions & 49 deletions dynamo/prediction/least_action_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@


class LeastActionPath(Trajectory):
"""
A class for computing the Least Action Path for a given function and initial conditions.
"""A class for computing the Least Action Path for a given function and initial conditions.
Args:
X: The initial conditions as a 2D array of shape (n, m), where n is the number of
Expand All @@ -47,8 +46,7 @@ class LeastActionPath(Trajectory):
"""

def __init__(self, X: np.ndarray, vf_func: Callable, D: float = 1, dt: float = 1) -> None:
"""
Initializes the LeastActionPath class instance with the given initial conditions, vector field function,
"""Initializes the LeastActionPath class instance with the given initial conditions, vector field function,
diffusion constant and time step.
Args:
Expand All @@ -66,26 +64,23 @@ def __init__(self, X: np.ndarray, vf_func: Callable, D: float = 1, dt: float = 1
self._action[i] = action(self.X[: i + 1], self.func, self.D, dt)

def get_t(self) -> np.ndarray:
"""
Returns the time points of the trajectory.
"""Returns the time points of the trajectory.
Returns:
ndarray: The time points of the trajectory.
The time points of the trajectory.
"""
return self.t

def get_dt(self) -> float:
"""
Returns the time step of the trajectory.
"""Returns the time step of the trajectory.
Returns:
float: The time step of the trajectory.
The time step of the trajectory.
"""
return np.mean(np.diff(self.t))

def action_t(self, t: Optional[float] = None, **interp_kwargs) -> np.ndarray:
"""
Returns the Least Action Path action values at time t.
"""Returns the Least Action Path action values at time t.
Args:
t: The time point(s) to return the action value(s) for.
Expand All @@ -94,23 +89,30 @@ def action_t(self, t: Optional[float] = None, **interp_kwargs) -> np.ndarray:
**interp_kwargs: Additional keyword arguments to pass to the interp1d function.
Returns:
ndarray: The Least Action Path action value(s).
The Least Action Path action value(s).
"""
if t is None:
return self._action
else:
return interp1d(self.t, self._action, **interp_kwargs)(t)

def mfpt(self, action: Optional[np.ndarray] = None) -> np.ndarray:
"""Eqn. 7 of Epigenetics as a first exit problem."""
"""Eqn. 7 of Epigenetics as a first exit problem.
Args:
action: The action values. If None, uses the action values stored in the _action attribute.
Returns:
The mean first passage time.
"""
action = self._action if action is None else action
return 1 / np.exp(-action)

def optimize_dt(self) -> float:
"""Optimizes the time step of the simulation to minimize the Least Action Path action.
Returns:
float: optimal time step
Optimal time step
"""
dt_0 = self.get_dt()
t_dict = minimize(lambda t: action(self.X, self.func, D=self.D, dt=t), dt_0)
Expand All @@ -122,10 +124,31 @@ def optimize_dt(self) -> float:


class GeneLeastActionPath(GeneTrajectory):
def __init__(self, adata, lap: LeastActionPath = None, X_pca=None, vf_func=None, D=1, dt=1, **kwargs) -> None:
"""
Calculates the least action path trajectory and action for a gene expression dataset.
Inherits from GeneTrajectory class.
"""A class for computing the least action path trajectory and action for a gene expression dataset.
Inherits from GeneTrajectory class.
Attributes:
adata: AnnData object containing the gene expression dataset.
X: Expression data.
to_pca: Transformation matrix from gene expression space to PCA space.
from_pca: Transformation matrix from PCA space to gene expression space.
PCs: Principal components from PCA analysis.
func: Vector field function reconstructed within the PCA space.
D: Diffusivity value.
t: Array of time values.
action: Array of action values.
"""
def __init__(
self,
adata: AnnData,
lap: LeastActionPath = None,
X_pca: Optional[np.ndarray] = None,
vf_func: Optional[Callable] = None,
D: float = 1,
dt: float = 1,
**kwargs,
) -> None:
"""Initializes the GeneLeastActionPath class instance.
Args:
adata: AnnData object containing the gene expression dataset.
Expand All @@ -135,17 +158,6 @@ def __init__(self, adata, lap: LeastActionPath = None, X_pca=None, vf_func=None,
D: Diffusivity value. Defaults to 1.
dt: Time step size. Defaults to 1.
**kwargs: Additional keyword arguments passed to the GeneTrajectory class.
Attributes:
adata: AnnData object containing the gene expression dataset.
X: Expression data.
to_pca: Transformation matrix from gene expression space to PCA space.
from_pca: Transformation matrix from PCA space to gene expression space.
PCs: Principal components from PCA analysis.
func: Vector field function reconstructed within the PCA space.
D: Diffusivity value.
t: Array of time values.
action: Array of action values.
"""
if lap is not None:
self.from_lap(adata, lap, **kwargs)
Expand All @@ -172,29 +184,26 @@ def from_lap(self, adata: AnnData, lap: LeastActionPath, **kwargs):
self.D = lap.D

def get_t(self) -> np.ndarray:
"""
Returns the array of time values.
"""Returns the array of time values.
Returns:
np.ndarray: Array of time values.
Array of time values.
"""
return self.t

def get_dt(self) -> float:
"""
Returns the average time step size.
"""Returns the average time step size.
Returns:
float: Average time step size.
Average time step size.
"""
return np.mean(np.diff(self.t))

def genewise_action(self) -> np.ndarray:
"""
Calculates the genewise action values.
"""Calculates the genewise action values.
Returns:
np.ndarray: Array of genewise action values.
Array of genewise action values.
"""
dt = self.get_dt()
x = (self.X[:-1] + self.X[1:]) * 0.5
Expand All @@ -206,25 +215,28 @@ def genewise_action(self) -> np.ndarray:
return s

def select_genewise_action(self, genes: Union[str, List[str]]) -> np.ndarray:
"""
Returns the genewise action values for the specified genes.
"""Returns the genewise action values for the specified genes.
Args:
genes (Union[str, List[str]]): List of gene names or a single gene name.
Returns:
np.ndarray: Array of genewise action values.
Array of genewise action values.
"""
return super().select_gene(genes, arr=self.action)


def action(path: np.ndarray, vf_func: Callable[[np.ndarray], np.ndarray], D: float = 1, dt: float = 1) -> float:
# centers
"""The action function calculates the action (or functional) of a path in space, given a velocity field function and diffusion coefficient. The path is represented as an array of points in space, and the velocity field is given by vf_func.
The function first calculates the centers of the segments between each point in the path, and then calculates the velocity at each of these centers by taking the average of the velocities at the two neighboring points. The difference between the actual velocity and the velocity field at each center is then calculated and flattened into a one-dimensional array.
"""The action function calculates the action (or functional) of a path in space, given a velocity field function
and diffusion coefficient.
The action is then calculated by taking the dot product of this array with itself, multiplying by a factor of 0.5*dt/D, where dt is the time step used to define the path, and D is the diffusion coefficient.
The path is represented as an array of points in space, and the velocity field is given by vf_func. The function
first calculates the centers of the segments between each point in the path, and then calculates the velocity at
each of these centers by taking the average of the velocities at the two neighboring points. The difference
between the actual velocity and the velocity field at each center is then calculated and flattened into a
one-dimensional array. The action is then calculated by taking the dot product of this array with itself,
multiplying by a factor of 0.5*dt/D, where dt is the time step used to define the path, and D is the diffusion
coefficient.
Args:
path: An array of shape (N, d) containing the coordinates of a path with N points in d dimensions.
Expand Down Expand Up @@ -273,14 +285,15 @@ def action_grad(path: np.ndarray, vf_func: Callable, jac_func: Callable, D: floa
"""Computes the gradient of the action functional with respect to the path.
Args:
path: A 2D array of shape (n+1,d) representing the path, where n is the number of time steps and d is the dimension of the path.
path: A 2D array of shape (n+1,d) representing the path, where n is the number of time steps and d is the
dimension of the path.
vf_func: A function that computes the velocity field vf(x) for a given position x.
jac_func: A function that computes the Jacobian matrix of the velocity field at a given position.
D: The diffusion constant (default is 1).
dt: The time step (default is 1).
Returns:
np.ndarray: The gradient of the action functional with respect to the path, as a 2D array of shape (n,d).
The gradient of the action functional with respect to the path, as a 2D array of shape (n,d).
"""
x = (path[:-1] + path[1:]) * 0.5
v = np.diff(path, axis=0) / dt
Expand Down
37 changes: 20 additions & 17 deletions dynamo/prediction/perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def KO(
vf_key: A key to the vector field functions in adata.uns.
basis: The basis in which the vector field function is created.
emb_basis: The embedding basis where the perturbed (KO) vector field function will be projected to.
velocity_ko_wt_difference: Whether to use the difference from perturbed (KO) vector field to wildtype vector field in embedding space
instead of raw perturbation (KO) vector field. Using the difference may reveal the perturbation (KO) effects more
clearly.
add_ko_basis_key: The key name for the velocity corresponds to the `basis` name whose associated vector field is perturbed
(KO).
add_embedding_key: The key name for the velocity corresponds to the `embedding` name to which the high dimensional perturbed
(KO) vector field will be projected to.
velocity_ko_wt_difference: Whether to use the difference from perturbed (KO) vector field to wildtype vector
field in embedding space instead of raw perturbation (KO) vector field. Using the difference may reveal the
perturbation (KO) effects more clearly.
add_ko_basis_key: The key name for the velocity that corresponds to the `basis` name whose associated vector
field is perturbed (KO).
add_embedding_key: The key name for the velocity that corresponds to the `embedding` name to which the high
dimensional perturbed (KO) vector field will be projected.
store_vf_ko: Whether to store the perturbed (KO) vector field function. By default it is False.
add_vf_ko_key: The key to store the perturbed (KO) vector field function in adata.uns.
return_vector_field_class: Whether to return the perturbed (KO) vector field class. By default it is True.
Expand Down Expand Up @@ -161,30 +161,33 @@ def perturbation(
Args:
adata: an AnnData object.
genes: The gene or list of genes that will be used to perform in-silico perturbation.
expression: The numerical value or list of values that will be used to encode the genetic perturbation. High positive
values indicates up-regulation while low negative value repression.
expression: The numerical value or list of values that will be used to encode the genetic perturbation. High
positive values indicate up-regulation, while low negative values indicate repression.
perturb_mode: The mode for perturbing the gene expression vector, either `raw` or `z_score`.
cells: The list of the cell indices that we will perform the perturbation.
zero_perturb_genes_vel: Whether to set the perturbed genes' perturbation velocity vector values to be zero.
pca_key: The key that corresponds to pca embedding. Can also be the actual embedding matrix.
PCs_key: The key that corresponds to PC loading embedding. Can also be the actual loading matrix.
pca_mean_key: The key that corresponds to means values that used for pca projection. Can also be the actual means matrix.
pca_mean_key: The key that corresponds to the mean values that are used for pca projection. Can also be the
actual means matrix.
basis: The key that corresponds to the basis from which the vector field is reconstructed.
jac_key: The key to the jacobian matrix.
X_pca: The pca embedding matrix.
delta_Y: The actual perturbation matrix. This argument enables more customized perturbation schemes.
projection_method: The approach that will be used to project the high dimensional perturbation effect vector to low dimensional
space.
projection_method: The approach that will be used to project the high dimensional perturbation effect vector to
low dimensional space.
pertubation_method: The approach that will be used to calculate the perturbation effect vector after in-silico genetic
perturbation. Can only be one of `"j_delta_x", "j_x_prime", "j_jv", "f_x_prime", "f_x_prime_minus_f_x_0"`
J_jv_delta_t: If pertubation_method is `j_jv`, this will be used to determine the $\\delta x = jv \\delta t_{jv}$
delta_t: This will be used to determine the $\\delta Y = jv \\delta t$
add_delta_Y_key: The key that will be used to store the perturbation effect matrix. Both the pca dimension matrix (stored in
obsm) or the matrix of the original gene expression space (stored in .layers) will use this key. By default
it is None and is set to be `method + '_perturbation'`.
add_delta_Y_key: The key that will be used to store the perturbation effect matrix. Both the pca dimension
matrix (stored in obsm) or the matrix of the original gene expression space (stored in .layers) will use
this key. By default it is None and is set to be `method + '_perturbation'`.
add_transition_key: The dictionary key that will be used for storing the transition matrix in .obsp.
add_velocity_key: The dictionary key that will be used for storing the low dimensional velocity projection matrix in .obsm.
add_embedding_key: The dictionary key that will be used for storing the low dimensional velocity projection matrix in .obsm.
add_velocity_key: The dictionary key that will be used for storing the low dimensional velocity projection
matrix in .obsm.
add_embedding_key: The dictionary key that will be used for storing the low dimensional velocity projection
matrix in .obsm.
Returns:
adata: Returns an updated :class:`~anndata.AnnData` with perturbation effect matrix, projected perturbation vectors
Expand Down

0 comments on commit 2b5ba65

Please sign in to comment.