diff --git a/dynamo/prediction/fate.py b/dynamo/prediction/fate.py index a31c979be..9c2a4deb2 100755 --- a/dynamo/prediction/fate.py +++ b/dynamo/prediction/fate.py @@ -91,8 +91,7 @@ def fate( kwargs: Additional parameters that will be passed into the fate function. Returns: - adata: AnnData object that is updated with the dictionary Fate (includes `t` and `prediction` keys) in uns - attribute. + AnnData object that is updated with the dictionary Fate (includes `t` and `prediction` keys) in uns attribute. """ if basis is not None: @@ -194,8 +193,14 @@ def _fate( multiprocessing will be used to parallel the fate prediction. Returns: - t: The time at which the cell state are predicted. - prediction: Predicted cells states at different time points. Row order corresponds to the element order in t. If init_states corresponds to multiple cells, the expression dynamics over time for each cell is concatenated by rows. That is, the final dimension of prediction is (len(t) * n_cells, n_features). n_cells: number of cells; n_features: number of genes or number of low dimensional embeddings. Of note, if the average is set to be True, the average cell state at each time point is calculated for all cells. + A tuple containing two elements: + t: The time at which the cell state are predicted. + prediction: Predicted cells states at different time points. Row order corresponds to the element order in + t. If init_states corresponds to multiple cells, the expression dynamics over time for each cell is + concatenated by rows. That is, the final dimension of prediction is (len(t) * n_cells, n_features). + n_cells: number of cells; n_features: number of genes or number of low dimensional embeddings. + Of note, if the average is set to be True, the average cell state at each time point is calculated for + all cells. """ if sampling == "uniform_indices": @@ -260,7 +265,18 @@ def _inverse_transform( basis: str = "umap", Qkey: str = "PCs", ) -> Tuple[Union[np.ndarray, List[np.ndarray]], np.ndarray]: - """Inverse transform the low dimensional vector field prediction back to high dimensional space.""" + """Inverse transform the low dimensional vector field prediction back to high dimensional space. + + Args: + adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute. + prediction: Predicted cells states at different time points. + basis: The embedding data to use for predicting cell fate. + Qkey: The key of the PCA loading matrix in `.uns`. + + Returns: + The predicted cells states at different time points in high dimensional space and the gene names whose gene + expression will be used for predicting cell fate. + """ if basis == "pca": if type(prediction) == list: exprs = [vector_transformation(cur_pred.T, adata.uns[Qkey]) for cur_pred in prediction] diff --git a/dynamo/prediction/least_action_path.py b/dynamo/prediction/least_action_path.py index b6ad3b0e6..430f3c649 100644 --- a/dynamo/prediction/least_action_path.py +++ b/dynamo/prediction/least_action_path.py @@ -20,8 +20,7 @@ class LeastActionPath(Trajectory): - """ - A class for computing the Least Action Path for a given function and initial conditions. + """A class for computing the Least Action Path for a given function and initial conditions. Args: X: The initial conditions as a 2D array of shape (n, m), where n is the number of @@ -47,8 +46,7 @@ class LeastActionPath(Trajectory): """ def __init__(self, X: np.ndarray, vf_func: Callable, D: float = 1, dt: float = 1) -> None: - """ - Initializes the LeastActionPath class instance with the given initial conditions, vector field function, + """Initializes the LeastActionPath class instance with the given initial conditions, vector field function, diffusion constant and time step. Args: @@ -66,26 +64,23 @@ def __init__(self, X: np.ndarray, vf_func: Callable, D: float = 1, dt: float = 1 self._action[i] = action(self.X[: i + 1], self.func, self.D, dt) def get_t(self) -> np.ndarray: - """ - Returns the time points of the trajectory. + """Returns the time points of the trajectory. Returns: - ndarray: The time points of the trajectory. + The time points of the trajectory. """ return self.t def get_dt(self) -> float: - """ - Returns the time step of the trajectory. + """Returns the time step of the trajectory. Returns: - float: The time step of the trajectory. + The time step of the trajectory. """ return np.mean(np.diff(self.t)) def action_t(self, t: Optional[float] = None, **interp_kwargs) -> np.ndarray: - """ - Returns the Least Action Path action values at time t. + """Returns the Least Action Path action values at time t. Args: t: The time point(s) to return the action value(s) for. @@ -94,7 +89,7 @@ def action_t(self, t: Optional[float] = None, **interp_kwargs) -> np.ndarray: **interp_kwargs: Additional keyword arguments to pass to the interp1d function. Returns: - ndarray: The Least Action Path action value(s). + The Least Action Path action value(s). """ if t is None: return self._action @@ -102,7 +97,14 @@ def action_t(self, t: Optional[float] = None, **interp_kwargs) -> np.ndarray: return interp1d(self.t, self._action, **interp_kwargs)(t) def mfpt(self, action: Optional[np.ndarray] = None) -> np.ndarray: - """Eqn. 7 of Epigenetics as a first exit problem.""" + """Eqn. 7 of Epigenetics as a first exit problem. + + Args: + action: The action values. If None, uses the action values stored in the _action attribute. + + Returns: + The mean first passage time. + """ action = self._action if action is None else action return 1 / np.exp(-action) @@ -110,7 +112,7 @@ def optimize_dt(self) -> float: """Optimizes the time step of the simulation to minimize the Least Action Path action. Returns: - float: optimal time step + Optimal time step """ dt_0 = self.get_dt() t_dict = minimize(lambda t: action(self.X, self.func, D=self.D, dt=t), dt_0) @@ -122,10 +124,31 @@ def optimize_dt(self) -> float: class GeneLeastActionPath(GeneTrajectory): - def __init__(self, adata, lap: LeastActionPath = None, X_pca=None, vf_func=None, D=1, dt=1, **kwargs) -> None: - """ - Calculates the least action path trajectory and action for a gene expression dataset. - Inherits from GeneTrajectory class. + """A class for computing the least action path trajectory and action for a gene expression dataset. + Inherits from GeneTrajectory class. + + Attributes: + adata: AnnData object containing the gene expression dataset. + X: Expression data. + to_pca: Transformation matrix from gene expression space to PCA space. + from_pca: Transformation matrix from PCA space to gene expression space. + PCs: Principal components from PCA analysis. + func: Vector field function reconstructed within the PCA space. + D: Diffusivity value. + t: Array of time values. + action: Array of action values. + """ + def __init__( + self, + adata: AnnData, + lap: LeastActionPath = None, + X_pca: Optional[np.ndarray] = None, + vf_func: Optional[Callable] = None, + D: float = 1, + dt: float = 1, + **kwargs, + ) -> None: + """Initializes the GeneLeastActionPath class instance. Args: adata: AnnData object containing the gene expression dataset. @@ -135,17 +158,6 @@ def __init__(self, adata, lap: LeastActionPath = None, X_pca=None, vf_func=None, D: Diffusivity value. Defaults to 1. dt: Time step size. Defaults to 1. **kwargs: Additional keyword arguments passed to the GeneTrajectory class. - - Attributes: - adata: AnnData object containing the gene expression dataset. - X: Expression data. - to_pca: Transformation matrix from gene expression space to PCA space. - from_pca: Transformation matrix from PCA space to gene expression space. - PCs: Principal components from PCA analysis. - func: Vector field function reconstructed within the PCA space. - D: Diffusivity value. - t: Array of time values. - action: Array of action values. """ if lap is not None: self.from_lap(adata, lap, **kwargs) @@ -172,29 +184,26 @@ def from_lap(self, adata: AnnData, lap: LeastActionPath, **kwargs): self.D = lap.D def get_t(self) -> np.ndarray: - """ - Returns the array of time values. + """Returns the array of time values. Returns: - np.ndarray: Array of time values. + Array of time values. """ return self.t def get_dt(self) -> float: - """ - Returns the average time step size. + """Returns the average time step size. Returns: - float: Average time step size. + Average time step size. """ return np.mean(np.diff(self.t)) def genewise_action(self) -> np.ndarray: - """ - Calculates the genewise action values. + """Calculates the genewise action values. Returns: - np.ndarray: Array of genewise action values. + Array of genewise action values. """ dt = self.get_dt() x = (self.X[:-1] + self.X[1:]) * 0.5 @@ -206,25 +215,28 @@ def genewise_action(self) -> np.ndarray: return s def select_genewise_action(self, genes: Union[str, List[str]]) -> np.ndarray: - """ - Returns the genewise action values for the specified genes. + """Returns the genewise action values for the specified genes. Args: genes (Union[str, List[str]]): List of gene names or a single gene name. Returns: - np.ndarray: Array of genewise action values. + Array of genewise action values. """ return super().select_gene(genes, arr=self.action) def action(path: np.ndarray, vf_func: Callable[[np.ndarray], np.ndarray], D: float = 1, dt: float = 1) -> float: - # centers - """The action function calculates the action (or functional) of a path in space, given a velocity field function and diffusion coefficient. The path is represented as an array of points in space, and the velocity field is given by vf_func. + """The action function calculates the action (or functional) of a path in space, given a velocity field function + and diffusion coefficient. - The function first calculates the centers of the segments between each point in the path, and then calculates the velocity at each of these centers by taking the average of the velocities at the two neighboring points. The difference between the actual velocity and the velocity field at each center is then calculated and flattened into a one-dimensional array. - - The action is then calculated by taking the dot product of this array with itself, multiplying by a factor of 0.5*dt/D, where dt is the time step used to define the path, and D is the diffusion coefficient. + The path is represented as an array of points in space, and the velocity field is given by vf_func. The function + first calculates the centers of the segments between each point in the path, and then calculates the velocity at + each of these centers by taking the average of the velocities at the two neighboring points. The difference + between the actual velocity and the velocity field at each center is then calculated and flattened into a + one-dimensional array. The action is then calculated by taking the dot product of this array with itself, + multiplying by a factor of 0.5*dt/D, where dt is the time step used to define the path, and D is the diffusion + coefficient. Args: path: An array of shape (N, d) containing the coordinates of a path with N points in d dimensions. @@ -244,24 +256,44 @@ def action(path: np.ndarray, vf_func: Callable[[np.ndarray], np.ndarray], D: flo return s -def action_aux(path_flatten, vf_func, dim, start=None, end=None, **kwargs): +def action_aux( + path_flatten: np.ndarray, + vf_func: Callable, + dim: int, + start: Optional[np.ndarray] = None, + end: Optional[np.ndarray] = None, + **kwargs, +) -> float: + """Auxiliary function for computing the action of a path. + + Args: + path_flatten: A 1D array representing the flattened path. + vf_func: A function that computes the velocity field vf(x) for a given position x. + dim: The dimension of the path. + start: The starting point of the path. + end: The end point of the path. + **kwargs: Additional keyword arguments to pass to the action function. + + Returns: + The action of the path. + """ path = reshape_path(path_flatten, dim, start=start, end=end) return action(path, vf_func, **kwargs) def action_grad(path: np.ndarray, vf_func: Callable, jac_func: Callable, D: float = 1.0, dt: float = 1.0) -> np.ndarray: - """ - Computes the gradient of the action functional with respect to the path. + """Computes the gradient of the action functional with respect to the path. Args: - path: A 2D array of shape (n+1,d) representing the path, where n is the number of time steps and d is the dimension of the path. + path: A 2D array of shape (n+1,d) representing the path, where n is the number of time steps and d is the + dimension of the path. vf_func: A function that computes the velocity field vf(x) for a given position x. jac_func: A function that computes the Jacobian matrix of the velocity field at a given position. D: The diffusion constant (default is 1). dt: The time step (default is 1). Returns: - np.ndarray: The gradient of the action functional with respect to the path, as a 2D array of shape (n,d). + The gradient of the action functional with respect to the path, as a 2D array of shape (n,d). """ x = (path[:-1] + path[1:]) * 0.5 v = np.diff(path, axis=0) / dt @@ -275,12 +307,50 @@ def action_grad(path: np.ndarray, vf_func: Callable, jac_func: Callable, D: floa return grad -def action_grad_aux(path_flatten, vf_func, jac_func, dim, start=None, end=None, **kwargs): +def action_grad_aux( + path_flatten: np.ndarray, + vf_func: Callable, + jac_func: Callable, + dim: int, + start: Optional[np.ndarray] = None, + end: Optional[np.ndarray] = None, + **kwargs, +) -> np.ndarray: + """Auxiliary function for computing the gradient of the action functional with respect to the path. + + Args: + path_flatten: A 1D array representing the flattened path. + vf_func: A function that computes the velocity field vf(x) for a given position x. + jac_func: A function that computes the Jacobian matrix of the velocity field at a given position. + dim: The dimension of the path. + start: The starting point of the path. + end: The end point of the path. + **kwargs: Additional keyword arguments to pass to the action_grad function. + + Returns: + The gradient of the action functional with respect to the path. + """ path = reshape_path(path_flatten, dim, start=start, end=end) return action_grad(path, vf_func, jac_func, **kwargs).flatten() -def reshape_path(path_flatten, dim, start=None, end=None): +def reshape_path( + path_flatten: np.ndarray, + dim: int, + start: Optional[np.ndarray] = None, + end: Optional[np.ndarray] = None, +) -> np.ndarray: + """Reshapes a flattened path into a 2D array. + + Args: + path_flatten: A 1D array representing the flattened path. + dim: The dimension of the path. + start: The starting point of the path. + end: The end point of the path. + + Returns: + A 2D array representing the path. + """ path = path_flatten.reshape(int(len(path_flatten) / dim), dim) if start is not None: path = np.vstack((start, path)) @@ -381,7 +451,36 @@ def least_action_path( return path, dt, action_opt -def minimize_lap_time(path_0, t0, t_min, vf_func, jac_func, D=1, num_t=20, elbow_method="hessian", hes_tol=3): +def minimize_lap_time( + path_0: np.ndarray, + t0: float, + t_min: float, + vf_func: Callable, + jac_func: Callable, + D: Union[float, int, np.ndarray] = 1, + num_t: int = 20, + elbow_method: str = "hessian", + hes_tol=3, +) -> Tuple[int, List[np.ndarray], np.ndarray, np.ndarray]: + """Minimize the least action path time. + + Args: + path_0: The initial path. + t0: The initial time to start the minimization. + t_min: The minimum time to consider. + vf_func: The vector field function. + jac_func: The Jacobian function. + D: The diffusion constant or matrix. + num_t: The number of time steps. + elbow_method: The method to use to find the elbow in the action vs time plot. + hes_tol: The tolerance to use for the elbow method. + + Returns: + A tuple containing the following elements: + - i_elbow: The index of the elbow in the action vs time plot. + - laps: A list of the least action paths for each time step. + - A: An array of action values for each time step. + """ T = np.linspace(t_min, t0, num_t) A = np.zeros(num_t) opt_T = np.zeros(num_t) @@ -398,7 +497,25 @@ def minimize_lap_time(path_0, t0, t_min, vf_func, jac_func, D=1, num_t=20, elbow return i_elbow, laps, A, opt_T -def get_init_path(G, start, end, coords, interpolation_num=20): +def get_init_path( + G: nx.Graph, + start: np.ndarray, + end: np.ndarray, + coords: np.ndarray, + interpolation_num: int = 20, +) -> np.ndarray: + """Get the initial path for the least action path calculation. + + Args: + G: A networkx graph representing the cell state space. + start: The starting point of the path. + end: The end point of the path. + coords: The coordinates of the cell states. + interpolation_num: The number of points to use in the initial path. + + Returns: + The initial path for the least action path calculation. + """ source_ind = nearest_neighbors(start, coords, k=1)[0][0] target_ind = nearest_neighbors(end, coords, k=1)[0][0] diff --git a/dynamo/prediction/perturbation.py b/dynamo/prediction/perturbation.py index 8f2923413..ab26f6571 100644 --- a/dynamo/prediction/perturbation.py +++ b/dynamo/prediction/perturbation.py @@ -43,13 +43,13 @@ def KO( vf_key: A key to the vector field functions in adata.uns. basis: The basis in which the vector field function is created. emb_basis: The embedding basis where the perturbed (KO) vector field function will be projected to. - velocity_ko_wt_difference: Whether to use the difference from perturbed (KO) vector field to wildtype vector field in embedding space - instead of raw perturbation (KO) vector field. Using the difference may reveal the perturbation (KO) effects more - clearly. - add_ko_basis_key: The key name for the velocity corresponds to the `basis` name whose associated vector field is perturbed - (KO). - add_embedding_key: The key name for the velocity corresponds to the `embedding` name to which the high dimensional perturbed - (KO) vector field will be projected to. + velocity_ko_wt_difference: Whether to use the difference from perturbed (KO) vector field to wildtype vector + field in embedding space instead of raw perturbation (KO) vector field. Using the difference may reveal the + perturbation (KO) effects more clearly. + add_ko_basis_key: The key name for the velocity corresponds to the `basis` name whose associated vector field + is perturbed (KO). + add_embedding_key: The key name for the velocity corresponds to the `embedding` name to which the high + dimensional perturbed (KO) vector field will be projected to. store_vf_ko: Whether to store the perturbed (KO) vector field function. By default it is False. add_vf_ko_key: The key to store the perturbed (KO) vector field function in adata.uns. return_vector_field_class: Whether to return the perturbed (KO) vector field class. By default it is True. @@ -161,30 +161,33 @@ def perturbation( Args: adata: an Annodata object. genes: The gene or list of genes that will be used to perform in-silico perturbation. - expression: The numerical value or list of values that will be used to encode the genetic perturbation. High positive - values indicates up-regulation while low negative value repression. + expression: The numerical value or list of values that will be used to encode the genetic perturbation. High + positive values indicates up-regulation while low negative value repression. perturb_mode: The mode for perturbing the gene expression vector, either `raw` or `z_score`. cells: The list of the cell indices that we will perform the perturbation. zero_perturb_genes_vel: Whether to set the peturbed genes' perturbation velocity vector values to be zero. pca_key: The key that corresponds to pca embedding. Can also be the actual embedding matrix. PCs_key: The key that corresponds to PC loading embedding. Can also be the actual loading matrix. - pca_mean_key: The key that corresponds to means values that used for pca projection. Can also be the actual means matrix. + pca_mean_key: The key that corresponds to means values that used for pca projection. Can also be the actual + means matrix. basis: The key that corresponds to the basis from which the vector field is reconstructed. jac_key: The key to the jacobian matrix. X_pca: The pca embedding matrix. delta_Y: The actual perturbation matrix. This argument enables more customized perturbation schemes. - projection_method: The approach that will be used to project the high dimensional perturbation effect vector to low dimensional - space. + projection_method: The approach that will be used to project the high dimensional perturbation effect vector to + low dimensional space. pertubation_method: The approach that will be used to calculate the perturbation effect vector after in-silico genetic perturbation. Can only be one of `"j_delta_x", "j_x_prime", "j_jv", "f_x_prime", "f_x_prime_minus_f_x_0"` J_jv_delta_t: If pertubation_method is `j_jv`, this will be used to determine the $\\delta x = jv \\delta t_{jv}$ delta_t: This will be used to determine the $\\delta Y = jv \\delta t$ - add_delta_Y_key: The key that will be used to store the perturbation effect matrix. Both the pca dimension matrix (stored in - obsm) or the matrix of the original gene expression space (stored in .layers) will use this key. By default - it is None and is set to be `method + '_perturbation'`. + add_delta_Y_key: The key that will be used to store the perturbation effect matrix. Both the pca dimension + matrix (stored in obsm) or the matrix of the original gene expression space (stored in .layers) will use + this key. By default it is None and is set to be `method + '_perturbation'`. add_transition_key: The dictionary key that will be used for storing the transition matrix in .obsp. - add_velocity_key: The dictionary key that will be used for storing the low dimensional velocity projection matrix in .obsm. - add_embedding_key: The dictionary key that will be used for storing the low dimensional velocity projection matrix in .obsm. + add_velocity_key: The dictionary key that will be used for storing the low dimensional velocity projection + matrix in .obsm. + add_embedding_key: The dictionary key that will be used for storing the low dimensional velocity projection + matrix in .obsm. Returns: adata: Returns an updated :class:`~anndata.AnnData` with perturbation effect matrix, projected perturbation vectors diff --git a/dynamo/prediction/trajectory.py b/dynamo/prediction/trajectory.py index 8b2f4a2a0..1b3cbadf3 100644 --- a/dynamo/prediction/trajectory.py +++ b/dynamo/prediction/trajectory.py @@ -1,7 +1,8 @@ -from typing import Callable, List, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import scipy +from anndata import AnnData from scipy.interpolate import interp1d from ..dynamo_logger import LoggerManager @@ -13,9 +14,9 @@ class Trajectory: + """Base class for handling trajectory interpolation, resampling, etc.""" def __init__(self, X: np.ndarray, t: Union[None, np.ndarray] = None, sort: bool = True) -> None: - """ - Base class for handling trajectory interpolation, resampling, etc. + """Initializes a Trajectory object. Args: X: trajectory positions, shape (n_points, n_dimensions) @@ -37,8 +38,7 @@ def __len__(self) -> int: return self.X.shape[0] def set_time(self, t: np.ndarray, sort: bool = True) -> None: - """ - Set the time stamps for the trajectory. Sorts the time stamps if requested. + """Set the time stamps for the trajectory. Sorts the time stamps if requested. Args: t: trajectory times, shape (n_points,) @@ -52,8 +52,7 @@ def set_time(self, t: np.ndarray, sort: bool = True) -> None: self.t = t def dim(self) -> int: - """ - Returns the number of dimensions in the trajectory. + """Returns the number of dimensions in the trajectory. Returns: number of dimensions in the trajectory @@ -61,8 +60,7 @@ def dim(self) -> int: return self.X.shape[1] def calc_tangent(self, normalize: bool = True): - """ - Calculate the tangent vectors of the trajectory. + """Calculate the tangent vectors of the trajectory. Args: normalize: whether to normalize the tangent vectors. Defaults to True. @@ -76,8 +74,7 @@ def calc_tangent(self, normalize: bool = True): return tvec def calc_arclength(self) -> float: - """ - Calculate the arc length of the trajectory. + """Calculate the arc length of the trajectory. Returns: arc length of the trajectory @@ -87,8 +84,7 @@ def calc_arclength(self) -> float: return np.sum(norms) def calc_curvature(self) -> np.ndarray: - """ - Calculate the curvature of the trajectory. + """Calculate the curvature of the trajectory. Returns: curvature of the trajectory, shape (n_points,) @@ -101,8 +97,7 @@ def calc_curvature(self) -> np.ndarray: return kappa def resample(self, n_points: int, tol: float = 1e-4, inplace: bool = True) -> Tuple[np.ndarray, np.ndarray]: - """ - Resample the curve with the specified number of points. + """Resample the curve with the specified number of points. Args: n_points: An integer specifying the number of points in the resampled curve. @@ -147,7 +142,7 @@ def archlength_sampling( sol: scipy.integrate._ivp.common.OdeSolution, interpolation_num: int, integration_direction: str, - ): + ) -> None: """Sample the curve using archlength sampling. Args: @@ -184,7 +179,7 @@ def logspace_sampling( sol: scipy.integrate._ivp.common.OdeSolution, interpolation_num: int, integration_direction: str, - ): + ) -> None: """Sample the curve using logspace sampling. Args: @@ -256,8 +251,7 @@ def interp_t(self, num: int = 100) -> np.ndarray: return np.linspace(self.t[0], self.t[-1], num=num) def interp_X(self, num: int = 100, **interp_kwargs) -> np.ndarray: - """ - Interpolates the curve at `num` equally spaced points in `t`. + """Interpolates the curve at `num` equally spaced points in `t`. Args: num: The number of points to interpolate the curve at. @@ -306,9 +300,9 @@ def calc_msd(self, decomp_dim: bool = True, ref: int = 0) -> Union[float, np.nda class VectorFieldTrajectory(Trajectory): + """Class for handling trajectory data with a differentiable vector field.""" def __init__(self, X: np.ndarray, t: np.ndarray, vecfld: DifferentiableVectorField) -> None: - """ - Initializes a VectorFieldTrajectory object. + """Initializes a VectorFieldTrajectory object. Args: X: The trajectory data as a numpy array of shape (n, d). @@ -321,8 +315,7 @@ def __init__(self, X: np.ndarray, t: np.ndarray, vecfld: DifferentiableVectorFie self.Js = None def get_velocities(self) -> np.ndarray: - """ - Calculates and returns the velocities along the trajectory. + """Calculates and returns the velocities along the trajectory. Returns: The velocity data as a numpy array of shape (n, d). @@ -331,9 +324,8 @@ def get_velocities(self) -> np.ndarray: self.data["velocity"] = self.vecfld.func(self.X) return self.data["velocity"] - def get_jacobians(self, method=None) -> np.ndarray: - """ - Calculates and returns the Jacobians of the vector field along the trajectory. + def get_jacobians(self, method: Optional[str] = None) -> np.ndarray: + """Calculates and returns the Jacobians of the vector field along the trajectory. Args: method: The method used to compute the Jacobians. @@ -346,9 +338,8 @@ def get_jacobians(self, method=None) -> np.ndarray: self.Js = fjac(self.X) return self.Js - def get_accelerations(self, method=None, **kwargs) -> np.ndarray: - """ - Calculates and returns the accelerations along the trajectory. + def get_accelerations(self, method: Optional[str] = None, **kwargs) -> np.ndarray: + """Calculates and returns the accelerations along the trajectory. Args: method: The method used to compute the Jacobians. @@ -363,9 +354,8 @@ def get_accelerations(self, method=None, **kwargs) -> np.ndarray: self.data["acceleration"] = self.vecfld.compute_acceleration(self.X, Js=self.Js, **kwargs) return self.data["acceleration"] - def get_curvatures(self, method=None, **kwargs) -> np.ndarray: - """ - Calculates and returns the curvatures along the trajectory. + def get_curvatures(self, method: Optional[str] = None, **kwargs) -> np.ndarray: + """Calculates and returns the curvatures along the trajectory. Args: method: The method used to compute the Jacobians. @@ -380,9 +370,8 @@ def get_curvatures(self, method=None, **kwargs) -> np.ndarray: self.data["curvature"] = self.vecfld.compute_curvature(self.X, Js=self.Js, **kwargs) return self.data["curvature"] - def get_divergences(self, method=None, **kwargs) -> np.ndarray: - """ - Calculates and returns the divergences along the trajectory. + def get_divergences(self, method: Optional[str] = None, **kwargs) -> np.ndarray: + """Calculates and returns the divergences along the trajectory. Args: method: The method used to compute the Jacobians. @@ -424,20 +413,32 @@ def calc_vector_msd(self, key: str, decomp_dim: bool = True, ref: int = 0) -> Un class GeneTrajectory(Trajectory): + """Class for handling gene expression trajectory data.""" def __init__( self, - adata, - X=None, - t=None, - X_pca=None, - PCs="PCs", - mean="pca_mean", - genes="use_for_pca", - expr_func=None, + adata: AnnData, + X: Optional[np.ndarray] = None, + t: Optional[np.ndarray] = None, + X_pca: Optional[np.ndarray] = None, + PCs: str = "PCs", + mean: str = "pca_mean", + genes: str = "use_for_pca", + expr_func: Optional[Callable] = None, **kwargs, ) -> None: - """ - This class is not fully functional yet. + """Initializes a GeneTrajectory object. + + Args: + adata: Anndata object containing the gene expression data. + X: The gene expression data as a numpy array of shape (n, d). Defaults to None. + t: The time data as a numpy array of shape (n,). Defaults to None. + X_pca: The PCA-transformed gene expression data as a numpy array of shape (n, d). Defaults to None. + PCs: The key in adata.uns to use for the PCA components. Defaults to "PCs". + mean: The key in adata.uns to use for the PCA mean. Defaults to "pca_mean". + genes: The key in adata.var to use for the genes. Defaults to "use_for_pca". + expr_func: A function to transform the PCA-transformed gene expression data back to the original space. + Defaults to None. + **kwargs: Additional keyword arguments to be passed to the superclass initializer. """ self.adata = adata if type(PCs) is str: @@ -460,22 +461,50 @@ def __init__( if X is not None: super().__init__(X, t=t) - def from_pca(self, X_pca, t=None): + def from_pca(self, X_pca: np.ndarray, t: Optional[np.ndarray] = None) -> None: + """Converts PCA-transformed gene expression data to gene expression data. + + Args: + X_pca: The PCA-transformed gene expression data as a numpy array of shape (n, d). + t: The time data as a numpy array of shape (n,). Defaults to None. + """ X = pca_to_expr(X_pca, self.PCs, mean=self.mean, func=self.expr_func) super().__init__(X, t=t) - def to_pca(self, x=None): + def to_pca(self, x: Optional[np.ndarray] = None) -> np.ndarray: + """Converts gene expression data to PCA-transformed gene expression data. + + Args: + x: The gene expression data as a numpy array of shape (n, d). Defaults to None. + + Returns: + The PCA-transformed gene expression data as a numpy array of shape (n, d). + """ if x is None: x = self.X return expr_to_pca(x, self.PCs, mean=self.mean, func=self.expr_func) - def genes_to_mask(self): + def genes_to_mask(self) -> np.ndarray: + """Returns a boolean mask for the genes in the trajectory. + + Returns: + A boolean mask for the genes in the trajectory. + """ mask = np.zeros(self.adata.n_vars, dtype=np.bool_) for g in self.genes: mask[self.adata.var_names == g] = True return mask - def calc_msd(self, save_key="traj_msd", **kwargs): + def calc_msd(self, save_key: str = "traj_msd", **kwargs) -> Union[float, np.ndarray]: + """Calculate the mean squared displacement (MSD) of the gene expression trajectory. + + Args: + save_key: The key to save the MSD data to in adata.var. Defaults to "traj_msd". + **kwargs: Additional keyword arguments to be passed to the superclass method. + + Returns: + The mean squared displacement of the gene expression trajectory. + """ msd = super().calc_msd(**kwargs) LoggerManager.main_logger.info_insert_adata(save_key, "var") @@ -484,12 +513,29 @@ def calc_msd(self, save_key="traj_msd", **kwargs): return msd - def save(self, save_key="gene_trajectory"): + def save(self, save_key: str = "gene_trajectory") -> None: + """Save the gene expression trajectory to adata.var. + + Args: + save_key: The key to save the gene expression trajectory to in adata.var. Defaults to "gene_trajectory". + """ LoggerManager.main_logger.info_insert_adata(save_key, "varm") self.adata.varm[save_key] = np.ones((self.adata.n_vars, self.X.shape[0])) * np.nan self.adata.varm[save_key][self.genes_to_mask(), :] = self.X.T - def select_gene(self, genes, arr=None, axis=None): + def select_gene( + self, genes: Union[np.ndarray, list], arr: Optional[np.ndarray] = None, axis: Optional[int] = None, + ) -> np.ndarray: + """Selects the gene expression data for the specified genes. + + Args: + genes: The genes to select the expression data for. + arr: The array to select the genes from. Defaults to None. + axis: The axis to select the genes along. Defaults to None. + + Returns: + The gene expression data for the specified genes. + """ if arr is None: arr = self.X if arr.ndim == 1: @@ -513,7 +559,19 @@ def select_gene(self, genes, arr=None, axis=None): return np.array(y) -def arclength_sampling_n(X, num, t=None): +def arclength_sampling_n( + X: np.ndarray, num: int, t: Optional[np.ndarray] = None, +) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, np.ndarray]]: + """Uniformly sample data points on an arc curve that generated from vector field predictions. + + Args: + X: The data points to sample from. + num: The number of points to sample. + t: The time values for the data points. Defaults to None. + + Returns: + The sampled data points and the arc length of the curve. + """ arclen = np.cumsum(np.linalg.norm(np.diff(X, axis=0), axis=1)) arclen = np.hstack((0, arclen)) @@ -526,8 +584,19 @@ def arclength_sampling_n(X, num, t=None): return X_, arclen[-1] -def remove_redundant_points_trajectory(X, tol=1e-4, output_discard=False): - """remove consecutive data points that are too close to each other.""" +def remove_redundant_points_trajectory( + X: np.ndarray, tol: float = 1e-4, output_discard: bool = False, +) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, np.ndarray]]: + """Remove consecutive data points that are too close to each other. + + Args: + X: The data points to remove redundant points from. + tol: The tolerance for removing redundant points. Defaults to 1e-4. + output_discard: Whether to output the discarded points. Defaults to False. + + Returns: + The data points with redundant points removed and the arc length of the curve. + """ X = np.atleast_2d(X) discard = np.zeros(len(X), dtype=bool) if X.shape[0] > 1: @@ -552,8 +621,18 @@ def remove_redundant_points_trajectory(X, tol=1e-4, output_discard=False): return (X, arclength) -def arclength_sampling(X, step_length, n_steps: int, t=None): - """uniformly sample data points on an arc curve that generated from vector field predictions.""" +def arclength_sampling(X: np.ndarray, step_length: float, n_steps: int, t: Optional[np.ndarray] = None) -> np.ndarray: + """Uniformly sample data points on an arc curve that generated from vector field predictions. + + Args: + X: The data points to sample from. + step_length: The length of each step. + n_steps: The number of steps to sample. + t: The time values for the data points. Defaults to None. + + Returns: + The sampled data points and the arc length of the curve. + """ Y = [] x0 = X[0] T = [] if t is not None else None diff --git a/dynamo/prediction/trajectory_analysis.py b/dynamo/prediction/trajectory_analysis.py index 53e5b25d4..8013a55e6 100644 --- a/dynamo/prediction/trajectory_analysis.py +++ b/dynamo/prediction/trajectory_analysis.py @@ -19,8 +19,7 @@ def calc_mean_exit_time(trajectories: List[Trajectory], in_init_state: Callable, in_sink_state: Callable) -> float: - """ - Calculates the mean exit time (MET) from the initial state to the sink state for a list of trajectories. + """Calculates the mean exit time (MET) from the initial state to the sink state for a list of trajectories. Args: trajectories: A list of trajectories. @@ -52,8 +51,7 @@ def calc_mean_exit_time(trajectories: List[Trajectory], in_init_state: Callable, def calc_mean_first_passage_time( trajectories: List[Trajectory], in_init_state: Callable, in_target_state: Callable, in_sink_state: Callable ) -> float: - """ - Calculates the mean first-passage time (MFPT) from the initial state to the target state for a list of trajectories. + """Calculates the mean first-passage time (MFPT) from the initial state to the target state for a list of trajectories. Args: trajectories: A list of trajectories. diff --git a/dynamo/prediction/utils.py b/dynamo/prediction/utils.py index 8f9baad01..ba46ce10e 100644 --- a/dynamo/prediction/utils.py +++ b/dynamo/prediction/utils.py @@ -1,7 +1,8 @@ -from typing import Callable, Union +from typing import Callable, List, Optional, Tuple, Union # from anndata._core.views import ArrayView import numpy as np +from anndata import AnnData from scipy import interpolate from scipy.integrate import solve_ivp from tqdm import tqdm @@ -60,18 +61,34 @@ def init_l0_chase( # --------------------------------------------------------------------------------------------------- # integration related def integrate_vf_ivp( - init_states, - t, - integration_direction, + init_states: np.ndarray, + t: np.ndarray, + integration_direction: str, f: Callable, - args=None, - interpolation_num=250, - average=True, - sampling="arc_length", - verbose=False, - disable=False, -): - """integrating along vector field function using the initial value problem solver from scipy.integrate""" + args: Optional[Tuple] = None, + interpolation_num: int = 250, + average: bool = True, + sampling: str = "arc_length", + verbose: bool = False, + disable: bool = False, +) -> Tuple[np.ndarray, np.ndarray]: + """Integrating along vector field function using the initial value problem solver from scipy.integrate. + + Args: + init_states: Initial states of the system. + t: Time points to integrate the system over. + integration_direction: The direction of integration. + f: The vector field function of the system. + args: Additional arguments to pass to the vector field function. + interpolation_num: Number of time points to interpolate the trajectories over. + average: Whether to average the trajectories. + sampling: The method of sampling points along a trajectory. + verbose: Whether to print the integration time. + disable: Whether to disable the progress bar. + + Returns: + The time and trajectories of the system. + """ # TODO: rewrite this function with the Trajectory class if init_states.ndim == 1: @@ -241,7 +258,28 @@ def integrate_vf_ivp( return t, Y -def integrate_sde(init_states, t, f, sigma, num_t=100, **interp_kwargs): +def integrate_sde( + init_states: Union[np.ndarray, list], + t: Union[float, np.ndarray], + f: Callable, + sigma: Union[float, np.ndarray, Callable], + num_t: int = 100, + **interp_kwargs, +) -> np.ndarray: + """Calculate the trajectories by integrating a system of stochastic differential equations (SDEs) using the sdeint + package. + + Args: + init_states: Initial states of the system. + t: Time points to integrate the system over. + f: The vector field function of the system. + sigma: The diffusion matrix of the system. + num_t: Number of time points to interpolate the trajectories over. + interp_kwargs: Additional keyword arguments to pass to the interpolation function. + + Returns: + The trajectories of the system. + """ try: from sdeint import itoint except: @@ -278,7 +316,25 @@ def integrate_sde(init_states, t, f, sigma, num_t=100, **interp_kwargs): return np.array(trajs) -def estimate_sigma(X, V, diff_multiplier=1.0, num_nbrs=30, nbr_idx=None): +def estimate_sigma( + X: np.ndarray, + V: np.ndarray, + diff_multiplier: int = 1.0, + num_nbrs: int = 30, + nbr_idx: Optional[np.ndarray] = None, +) -> np.ndarray: + """Estimate the diffusion matrix of the system using the vector field and the data. + + Args: + X: The array representing cell states. + V: The array representing velocity. + diff_multiplier: The multiplier for the diffusion matrix. + num_nbrs: The number of nearest neighbors to use for the estimation. + nbr_idx: The indices of the nearest neighbors. + + Returns: + The estimated diffusion matrix. + """ if nbr_idx is None: nbr_idx = nearest_neighbors(X, X, k=num_nbrs) @@ -295,16 +351,30 @@ def estimate_sigma(X, V, diff_multiplier=1.0, num_nbrs=30, nbr_idx=None): def integrate_streamline( - X, - Y, - U, - V, - integration_direction, - init_states, - interpolation_num=100, - average=True, -): - """use streamline's integrator to alleviate stacking of the solve_ivp. Need to update with the correct time.""" + X: np.ndarray, + Y: np.ndarray, + U: np.ndarray, + V: np.ndarray, + integration_direction: str, + init_states: np.ndarray, + interpolation_num: int = 100, + average: bool = True, +) -> np.ndarray: + """Use streamline's integrator to alleviate stacking of the solve_ivp. Need to update with the correct time. + + Args: + X: The x-coordinates of the grid. + Y: The y-coordinates of the grid. + U: The x-components of the velocity. + V: The y-components of the velocity. + integration_direction: The direction of integration. + init_states: The initial states of the system. + interpolation_num: The number of time points to interpolate the trajectories over. + average: Whether to average the trajectories. + + Returns: + The time and trajectories of the system. + """ import matplotlib.pyplot as plt n_cell = init_states.shape[0] @@ -355,7 +425,31 @@ def integrate_streamline( # --------------------------------------------------------------------------------------------------- # fate related -def fetch_exprs(adata, basis, layer, genes, time, mode, project_back_to_high_dim, traj_ind): +def fetch_exprs( + adata: AnnData, + basis: str, + layer: str, + genes: Union[str, list], + time: str, + mode: str, + project_back_to_high_dim: bool, + traj_ind: int, +) -> Tuple: + """Fetch the expression data for the given genes and time points. + + Args: + adata: The AnnData object. + basis: Target basis to fetch. + layer: Target layer to fetch. + genes: Target genes to consider. + time: The time information. + mode: The mode of the trajectory. + project_back_to_high_dim: Whether to project the data back to high dimension. + traj_ind: The index of the trajectory. + + Returns: + The expression data for the given genes and time points. + """ if type(genes) != list: genes = list(genes) @@ -412,14 +506,33 @@ def fetch_exprs(adata, basis, layer, genes, time, mode, project_back_to_high_dim # perturbation related -def z_score(X, axis=1): +def z_score(X: np.ndarray, axis: int = 1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calculate the z-score of the given data. + + Args: + X: The data to calculate the z-score for. + axis: The axis to calculate the z-score over. + + Returns: + The z-score of the data. + """ s = X.std(axis) m = X.mean(axis) Z = ((X.T - m) / s).T return Z, m, s -def z_score_inv(Z, m, s): +def z_score_inv(Z: Union[float, List, np.ndarray], m: np.ndarray, s: np.ndarray) -> np.ndarray: + """The inverse operation of z-score calculation. + + Args: + Z: The z-scored data. + m: The mean of the original data. + s: The standard deviation of the original data. + + Returns: + The original data reconstructed from the z-scored data. + """ if isarray(Z): X = (Z.T * s + m).T else: @@ -431,7 +544,17 @@ def z_score_inv(Z, m, s): # state graph related -def get_path(Pr, i, j): +def get_path(Pr: np.ndarray, i: int, j: int) -> List: + """Retrieve the shortest path from node i to node j in given graph. + + Args: + Pr: The graph. + i: The start node. + j: The end node. + + Returns: + The shortest path from node i to node j. + """ path = [j] k = j while Pr[i, k] != -9999: @@ -444,10 +567,25 @@ def get_path(Pr, i, j): # least action path related -def interp_second_derivative(t, f, num=5e2, interp_kind="cubic", **interp_kwargs): - """ - interpolate f(t) and calculate the discrete second derivative using: +def interp_second_derivative( + t: np.ndarray, + f: np.ndarray, + num: int = 5e2, + interp_kind="cubic", + **interp_kwargs, +) -> Tuple[np.ndarray, np.ndarray]: + """Interpolate f(t) and calculate the discrete second derivative using: d^2 f / dt^2 = (f(x+h1) - 2f(x) + f(x-h2)) / (h1 * h2) + + Args: + t: The time points. + f: The function values corresponding to the time points. + num: The number of points to interpolate to. + interp_kind: The kind of interpolation to use. + interp_kwargs: Additional keyword arguments to pass to the interpolation function. + + Returns: + The interpolated time points and the discrete second derivative. """ t_ = np.linspace(t[0], t[-1], int(num)) f_ = interpolate.interp1d(t, f, kind=interp_kind, **interp_kwargs)(t_) @@ -463,8 +601,25 @@ def interp_second_derivative(t, f, num=5e2, interp_kind="cubic", **interp_kwargs return t_, d2fdt2 -def interp_curvature(t, f, num=5e2, interp_kind="cubic", **interp_kwargs): - """""" +def interp_curvature( + t: np.ndarray, + f: np.ndarray, + num: int = 5e2, + interp_kind="cubic", + **interp_kwargs, +) -> Tuple[np.ndarray, np.ndarray]: + """Interpolate f(t) and calculate the curvature. + + Args: + t: The time points. + f: The function values corresponding to the time points. + num: The number of points to interpolate to. + interp_kind: The kind of interpolation to use. + interp_kwargs: Additional keyword arguments to pass to the interpolation function. + + Returns: + The interpolated time points and the curvature. + """ t_ = np.linspace(t[0], t[-1], int(num)) f_ = interpolate.interp1d(t, f, kind=interp_kind, **interp_kwargs)(t_) @@ -484,7 +639,17 @@ def interp_curvature(t, f, num=5e2, interp_kind="cubic", **interp_kwargs): return t_, cur -def kneedle_difference(t, f, type="decrease"): +def kneedle_difference(t: np.ndarray, f: np.ndarray, type: str = "decrease") -> np.ndarray: + """Calculate the difference between the function and the diagonal line. + + Args: + t: The time points. + f: The function values corresponding to the time points. + type: The type of function to use. + + Returns: + The difference between the function and the diagonal line. + """ if type == "decrease": diag_line = lambda x: -x + 1 elif type == "increase": @@ -498,7 +663,25 @@ def kneedle_difference(t, f, type="decrease"): return res -def find_elbow(T, F, method="kneedle", order=1, **kwargs): +def find_elbow( + T: np.ndarray, + F: np.ndarray, + method: str = "kneedle", + order: int = 1, + **kwargs, +) -> int: + """Find the elbow of the given function. + + Args: + T: The time points. + F: The function values corresponding to the time points. + method: The method to use for finding the elbow. + order: The order of the elbow. + kwargs: Additional keyword arguments to pass to the elbow finding function. + + Returns: + The index of the elbow. + """ i_elbow = None if method == "hessian": T_ = normalize(T)