-
Notifications
You must be signed in to change notification settings - Fork 59
/
fate.py
executable file
·716 lines (622 loc) · 33 KB
/
fate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
import itertools
import warnings
from multiprocessing.dummy import Pool as ThreadPool
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from anndata import AnnData
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
from ..configuration import DKM
from ..dynamo_logger import (
LoggerManager,
main_info,
main_info_insert_adata,
main_warning,
)
from ..utils import pca_to_expr
from ..tools.connectivity import construct_mapper_umap, correct_hnsw_neighbors, k_nearest_neighbors
from ..tools.utils import fetch_states, getTseq
from ..vectorfield import vector_field_function
from ..vectorfield.utils import vecfld_from_adata, vector_transformation
from .utils import integrate_vf_ivp
def fate(
    adata: AnnData,
    init_cells: list,
    init_states: Optional[np.ndarray] = None,
    basis: Optional[str] = None,
    layer: str = "X",
    dims: Optional[Union[int, List[int], Tuple[int], np.ndarray]] = None,
    genes: Optional[List] = None,
    t_end: Optional[float] = None,
    direction: str = "both",
    interpolation_num: int = 250,
    average: bool = False,
    sampling: str = "arc_length",
    VecFld_true: Optional[Callable] = None,
    inverse_transform: bool = False,
    Qkey: str = "PCs",
    scale: float = 1,
    cores: int = 1,
    **kwargs: dict,
) -> AnnData:
    """Predict the historical and future cell transcriptomic states over arbitrary time scales.

    This is achieved by integrating the reconstructed vector field function from one or a set of initial cell state(s).
    Note that this function is designed so that there is only one trajectory (based on averaged cell states if multiple
    initial states are provided) will be returned. `dyn.tl._fate` can be used to calculate multiple cell states.

    Args:
        adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute.
        init_cells: Cell name or indices of the initial cell states for the historical or future cell state prediction
            with numerical integration. If the names in init_cells not found in the adata.obs_name, it will be treated
            as cell indices and must be integers.
        init_states: Initial cell states for the historical or future cell state prediction with numerical integration.
        basis: The embedding data to use for predicting cell fate. If `basis` is either `umap` or `pca`, the
            reconstructed trajectory will be projected back to high dimensional space via the `inverse_transform`
            function.
        layer: Which layer of the data will be used for predicting cell fate with the reconstructed vector field
            function. The layer once provided, will override the `basis` argument and then predicting cell fate in high
            dimensional space.
        dims: The dimensions that will be selected for fate prediction.
        genes: The gene names whose gene expression will be used for predicting cell fate. NOTE(review): this argument
            is currently not referenced anywhere in the function body; it is retained for API compatibility.
        t_end: The length of the time period from which to predict cell state forward or backward over time. This is
            used by the odeint function.
        direction: The direction to predict the cell fate. One of the `forward`, `backward` or `both` string.
        interpolation_num: The number of uniformly interpolated time points.
        average: The method to calculate the average cell state at each time step, can be one of `origin` or
            `trajectory`. If `origin` used, the average expression state from the init_cells will be calculated and the
            fate prediction is based on this state. If `trajectory` used, the average expression states of all cells
            predicted from the vector field function at each time point will be used. If `average` is `False`, no
            averaging will be applied. If `average` is True, `origin` will be used.
        sampling: Methods to sample points along the integration path, one of
            `{'arc_length', 'logspace', 'uniform_indices'}`. If `logspace`, we will sample time points linearly on log
            space. If `uniform_indices`, the sorted unique set of all time points from all cell states' fate prediction
            will be used and then evenly sampled up to `interpolation_num` time points. If `arc_length`, we will sample
            the integration path with uniform arc length.
        VecFld_true: The true ODE function, useful when the data is generated through simulation. Replaces the
            reconstructed vector field when set.
        inverse_transform: Whether to inverse transform the low dimensional vector field prediction back to high
            dimensional space.
        Qkey: The key of the PCA loading matrix in `.uns`.
        scale: The value that will be used to scale the predicted velocity value from the reconstructed vector field
            function.
        cores: Number of cores to calculate path integral for predicting cell fate. If cores is set to be > 1,
            multiprocessing will be used to parallel the fate prediction.
        kwargs: Additional parameters that will be passed into the fate function.

    Returns:
        AnnData object that is updated with the dictionary Fate (includes `t` and `prediction` keys) in uns attribute.
    """
    # The result key in .uns is derived from the basis when given, else from the layer.
    if basis is not None:
        fate_key = "fate_" + basis
    else:
        fate_key = "fate" if layer == "X" else "fate_" + layer

    init_states, VecFld, t_end, valid_genes = fetch_states(
        adata,
        init_states,
        init_cells,
        basis,
        layer,
        average,
        t_end,
    )

    # Optionally restrict the initial states to a subset of dimensions: a scalar keeps
    # the first `dims` dimensions, a sequence selects those specific dimensions.
    if np.isscalar(dims):
        init_states = init_states[:, :dims]
    elif dims is not None:
        init_states = init_states[:, dims]

    # Use the true ODE function when provided (simulated data); otherwise wrap the
    # reconstructed vector field, scaled by `scale`.
    vf = (
        (lambda x: scale * vector_field_function(x=x, vf_dict=VecFld, dim=dims)) if VecFld_true is None else VecFld_true
    )

    t, prediction = _fate(
        vf,
        init_states,
        t_end=t_end,
        direction=direction,
        interpolation_num=interpolation_num,
        # only `trajectory` averaging happens inside _fate; `origin` averaging is done by fetch_states.
        average=average == "trajectory",
        sampling=sampling,
        cores=cores,
        **kwargs,
    )

    exprs = None
    if inverse_transform:
        exprs, valid_genes = _inverse_transform(adata=adata, prediction=prediction, basis=basis, Qkey=Qkey)

    adata.uns[fate_key] = {
        "init_states": init_states,
        "init_cells": list(init_cells),
        "average": average,
        "t": t,
        "prediction": prediction,
        "genes": valid_genes,
    }
    if exprs is not None:
        adata.uns[fate_key]["exprs"] = exprs
    return adata
def _fate(
    VecFld: Callable,
    init_states: np.ndarray,
    t_end: Optional[float] = None,
    step_size: Optional[float] = None,
    direction: str = "both",
    interpolation_num: int = 250,
    average: bool = True,
    sampling: str = "arc_length",
    cores: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Predict the historical and future cell transcriptomic states over arbitrary time scales by integrating vector
    field functions from one or a set of initial cell state(s).

    Args:
        VecFld: Functional form of the vector field reconstructed from sparse single cell samples. It is applicable to
            the entire transcriptomic space.
        init_states: Initial cell states for the historical or future cell state prediction with numerical integration.
        t_end: The length of the time period from which to predict cell state forward or backward over time. This is
            used by the odeint function.
        step_size: Step size for integrating the future or history cell state, used by the odeint function. By default
            it is None, and the step_size will be automatically calculated to ensure 250 total integration time-steps
            will be used.
        direction: The direction to predict the cell fate. One of the `forward`, `backward` or `both` string.
        interpolation_num: The number of uniformly interpolated time points.
        average: A boolean flag to determine whether to smooth the trajectory by calculating the average cell state at
            each time step.
        sampling: Methods to sample points along the integration path, one of
            `{'arc_length', 'logspace', 'uniform_indices'}`. If `logspace`, we will sample time points linearly on log
            space. If `uniform_indices`, the sorted unique set of all time points from all cell states' fate prediction
            will be used and then evenly sampled up to `interpolation_num` time points. If `arc_length`, we will sample
            the integration path with uniform arc length.
        cores: Number of cores to calculate path integral for predicting cell fate. If cores is set to be > 1,
            multiprocessing will be used to parallel the fate prediction.

    Returns:
        A tuple containing two elements:
            t: The time at which the cell state are predicted.
            prediction: Predicted cells states at different time points. Row order corresponds to the element order in
                t. If init_states corresponds to multiple cells, the expression dynamics over time for each cell is
                concatenated by rows. That is, the final dimension of prediction is (len(t) * n_cells, n_features).
                n_cells: number of cells; n_features: number of genes or number of low dimensional embeddings.
                Of note, if the average is set to be True, the average cell state at each time point is calculated for
                all cells.
    """
    # `uniform_indices` sampling needs the time points of ALL cells at once, which is
    # incompatible with the per-cell fan-out below, so force single-core integration.
    if sampling == "uniform_indices":
        main_warning(
            f"Uniform_indices method sample data points from all time points. The multiprocessing will be disabled."
        )
        cores = 1
    t_linspace = getTseq(init_states, t_end, step_size)
    if cores == 1:
        # Single-core path: integrate all initial states in one call.
        t, prediction = integrate_vf_ivp(
            init_states,
            t_linspace,
            direction,
            VecFld,
            interpolation_num=interpolation_num,
            average=average,
            sampling=sampling,
        )
    else:
        # Multi-core path: one integrate_vf_ivp call per initial state (one row of
        # init_states per task). Arguments are passed POSITIONALLY via starmap, so the
        # order of the itertools.repeat entries must match integrate_vf_ivp's signature;
        # the trailing False/True presumably map to progress/tqdm-related flags — TODO
        # confirm against integrate_vf_ivp's parameter list.
        pool = ThreadPool(cores)
        res = pool.starmap(
            integrate_vf_ivp,
            zip(
                init_states,
                itertools.repeat(t_linspace),
                itertools.repeat(direction),
                itertools.repeat(VecFld),
                itertools.repeat(()),
                itertools.repeat(interpolation_num),
                itertools.repeat(average),
                itertools.repeat(sampling),
                itertools.repeat(False),
                itertools.repeat(True),
            ),
        )  # disable tqdm when using multiple cores.
        pool.close()
        pool.join()
        # Each per-cell result is a (t_list, prediction_list) pair with a single entry;
        # unwrap the inner lists so t/prediction are flat lists over cells.
        t_, prediction_ = zip(*res)
        t, prediction = [i[0] for i in t_], [i[0] for i in prediction_]
    if init_states.shape[0] > 1 and average:
        # Average across cells: each cell contributes a contiguous run of t_len columns
        # in the horizontally stacked prediction, so the column for cell c at time i is
        # c * t_len + i. Assumes every cell's trajectory has the same number of time
        # points (t_len) — which holds when interpolation was applied above.
        t_stack, prediction_stack = np.hstack(t), np.hstack(prediction)
        n_cell, n_feature = init_states.shape
        t_len = int(len(t_stack) / n_cell)
        avg = np.zeros((n_feature, t_len))
        for i in range(t_len):
            avg[:, i] = np.mean(prediction_stack[:, np.arange(n_cell) * t_len + i], 1)
        prediction = [avg]
        # Collapse the per-cell time vectors into one sorted, de-duplicated time axis.
        t = [np.sort(np.unique(t))]
    return t, prediction
def _inverse_transform(
    adata: AnnData,
    prediction: Union[np.ndarray, List[np.ndarray]],
    basis: str = "umap",
    Qkey: str = "PCs",
) -> Tuple[Union[np.ndarray, List[np.ndarray]], np.ndarray]:
    """Inverse transform the low dimensional vector field prediction back to high dimensional space.

    Args:
        adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute.
        prediction: Predicted cell states at different time points, either a single (n_dim, n_time) array or a
            list of such arrays (one per initial cell).
        basis: The embedding the prediction lives in; only `pca` and `umap` are supported.
        Qkey: The key of the PCA loading matrix in `.uns`.

    Returns:
        The predicted cells states at different time points in high dimensional space and the gene names whose gene
        expression will be used for predicting cell fate.

    Raises:
        ValueError: If `basis` is neither `pca` nor `umap`.
        Exception: If the recovered dimensionality matches neither the dynamics genes nor the transition genes.
    """
    if basis == "pca":
        # Columns of each prediction array are time points; vector_transformation expects them in rows.
        if isinstance(prediction, list):
            exprs = [vector_transformation(cur_pred.T, adata.uns[Qkey]) for cur_pred in prediction]
            high_p_n = exprs[0].shape[1]
        else:
            exprs = vector_transformation(prediction.T, adata.uns[Qkey])
            high_p_n = exprs.shape[1]

        # Map the recovered dimensionality back to the gene set it was built from.
        if adata.var.use_for_dynamics.sum() == high_p_n:
            valid_genes = adata.var_names[adata.var.use_for_dynamics]
        else:
            valid_genes = adata.var_names[adata.var.use_for_transition]
    elif basis == "umap":
        # this requires umap 0.4; reverse project to PCA space.
        if hasattr(prediction, "ndim"):
            if prediction.ndim == 1:
                prediction = prediction[None, :]
        params = adata.uns["umap_fit"]
        # Rebuild the fitted UMAP mapper from its stored parameters so inverse_transform can be called.
        umap_fit = construct_mapper_umap(
            params["X_data"],
            n_components=params["umap_kwargs"]["n_components"],
            metric=params["umap_kwargs"]["metric"],
            min_dist=params["umap_kwargs"]["min_dist"],
            spread=params["umap_kwargs"]["spread"],
            max_iter=params["umap_kwargs"]["max_iter"],
            alpha=params["umap_kwargs"]["alpha"],
            gamma=params["umap_kwargs"]["gamma"],
            negative_sample_rate=params["umap_kwargs"]["negative_sample_rate"],
            init_pos=params["umap_kwargs"]["init_pos"],
            random_state=params["umap_kwargs"]["random_state"],
            umap_kwargs=params["umap_kwargs"],
        )
        PCs = adata.uns[Qkey].T
        exprs = []
        for cur_pred in prediction:
            expr = umap_fit.inverse_transform(cur_pred.T)
            # further reverse project back to raw expression space
            if PCs.shape[0] == expr.shape[1]:
                expr = np.expm1(expr @ PCs + adata.uns["pca_mean"])
            exprs.append(expr)

        if adata.var.use_for_dynamics.sum() == exprs[0].shape[1]:
            valid_genes = adata.var_names[adata.var.use_for_dynamics]
        elif adata.var.use_for_transition.sum() == exprs[0].shape[1]:
            valid_genes = adata.var_names[adata.var.use_for_transition]
        else:
            raise Exception(
                "looks like a customized set of genes is used for pca analysis of the adata. "
                "Try rerunning pca analysis with default settings for this function to work."
            )
    else:
        raise ValueError(f"Inverse transform with basis {basis} is not supported.")
    return exprs, valid_genes
def fate_bias(
    adata: AnnData,
    group: str,
    basis: str = "umap",
    inds: Union[list, None] = None,
    use_sink_percentage: bool = True,
    step_used_percentage: Optional[float] = None,
    speed_percentile: float = 5,
    dist_threshold: Optional[float] = None,
    source_groups: Optional[list] = None,
    metric: str = "euclidean",
    metric_kwds: dict = None,
    cores: int = 1,
    seed: int = 19491001,
    **kwargs,
) -> pd.DataFrame:
    """Calculate the lineage (fate) bias of states whose trajectory are predicted.

    Fate bias is currently calculated as the percentage of points along the predicted cell fate trajectory whose
    distance to their 0-th nearest neighbors on the data are close enough (determined by median 1-st nearest neighbors
    of all observed cells and the dist_threshold) to any cell from each group specified by `group` key. The details is
    described as following:

    Cell fate predicted by our vector field method sometimes end up in regions that are not sampled with cells. We thus
    developed a heuristic method to iteratively walk backward the integration path to assign cell fate. We first
    identify the regions with small velocity in the tail of the integration path (determined by `speed_percentile`),
    then we check whether the distance of 0-th nearest points on the observed data to all those points are far away
    from the observed data (determined by `dist_threshold`). If they are not all close to data, we then walk backwards
    along the trajectory by one time step until the distance of any currently visited integration path's data points'
    0-th nearest points to the observed cells is close enough. In order to calculate the cell fate probability, we
    diffuse one step further of the identified nearest neighbors from the integration to identify more nearest observed
    cells, especially those from terminal cell types in case nearby cells first identified are all close to some random
    progenitor cells. Then we use group information of those observed cells to define the fate probability.

    `fate_bias` calculate a confidence score for the calculated fate probability with a simple metric, defined as
    :math:`1 - (sum(distances > dist_threshold * median_dist) + walk_back_steps) / (len(indices) + walk_back_steps)`

    The `distance` is currently visited integration path's data points' 0-th nearest points to the observed cells.
    `median_dist` is median distance of their 1-st nearest cell distance of all observed cells. `walk_back_steps` is
    the steps walked backward along the integration path until all currently visited integration points's 0-th nearest
    points to the observed cells satisfy the distance threshold. `indices` are the time indices of integration points
    that is regarded as the regions with `small velocity` (note when walking backward, those corresponding points do
    not necessarily have small velocity anymore).

    Args:
        adata: AnnData object that contains the predicted fate trajectories in the `uns` attribute.
        group: The column key that corresponds to the cell type or other group information for quantifying the bias of
            cell state.
        basis: The embedding data space where cell fates were predicted and cell fates bias will be quantified.
        inds: The indices of the time steps that will be used for calculating fate bias. Otherwise inds need to be a
            list of integers of the time steps.
        use_sink_percentage: If inds is None and use_sink_percentage is True, sink calculation will be applied to
            calculate indices used for fate bias calculation.
        step_used_percentage: If inds is None and step_used_percentage is not None, step_used_percentage will be
            regarded as a percentage, and the LAST step_used_percentage of steps will be used for fate bias
            calculation.
        speed_percentile: The percentile of speed that will be used to determine the terminal cells (or sink region on
            the prediction path where speed is smaller than this speed percentile).
        dist_threshold: A multiplier of the median nearest cell distance on the embedding to determine cells that are
            outside the sampled domain of cells. If the mean distance of identified "terminal cells" is above this
            number, we will look backward along the trajectory (by minimize all indices by 1) until it finds cells
            satisfy this threshold. By default it is set to be 1 to ensure only considering points that are very close
            to observed data points.
        source_groups: The groups that corresponds to progenitor groups. They need to have at least one intersection
            with the groups from the `group` column. If group is not `None`, any identified "source_groups" cells that
            happen to be in those groups will be ignored and the probability of cell fate of those cells will be
            reassigned to the group that has the highest fate probability among other non source_groups group cells.
        metric: The distance metric to use for the tree. The default metric with p=2 is equivalent to the standard
            Euclidean metric. See the documentation of :class:`DistanceMetric` for a list of available metrics. If
            metric is "precomputed", X is assumed to be a distance matrix and must be square during fit. X may be a
            :term:`sparse graph`, in which case only "nonzero" elements may be considered neighbors.
        metric_kwds: Additional keyword arguments for the metric function.
        cores: The number of parallel jobs to run for neighbors search. ``None`` means 1 unless in a
            :obj:`joblib.parallel_backend` context. ``-1`` means using all processors.
        seed: Random seed to ensure the reproducibility of each run.
        kwargs: Additional arguments that will be passed to each nearest neighbor search algorithm.

    Returns:
        fate_bias: A DataFrame that stores the fate bias for each cell state (row) to each cell group (column).

    Raises:
        ValueError: If `group` is not an `.obs` key, the basis is not in `.obsm`, the fate prediction has not been
            run yet, or `source_groups` shares no member with the `group` column.
    """
    if dist_threshold is None:
        dist_threshold = 1

    if group not in adata.obs.keys():
        raise ValueError(f"The group {group} you provided is not a key of .obs attribute.")
    else:
        clusters = adata.obs[group]

    basis_key = "X_" + basis if basis is not None else "X"
    fate_key = "fate_" + basis if basis is not None else "fate"

    if basis_key not in adata.obsm.keys():
        raise ValueError(f"The basis {basis_key} you provided is not a key of .obsm attribute.")
    if fate_key not in adata.uns.keys():
        raise ValueError(
            f"The {fate_key} key is not existed in the .uns attribute of the adata object. You need to run"
            f"dyn.pd.fate(adata, basis='{basis}') before calculate fate bias."
        )

    if source_groups is not None:
        if isinstance(source_groups, str):
            source_groups = [source_groups]
        source_groups = list(set(source_groups).intersection(clusters))
        if len(source_groups) == 0:
            raise ValueError(
                f"the {source_groups} you provided doesn't intersect with any groups in the {group} column."
            )

    X = adata.obsm[basis_key] if basis_key != "X" else adata.X

    # kNN index over the observed cells; `metric_kwads` matches the helper's parameter spelling.
    knn, distances, nbrs, alg = k_nearest_neighbors(
        X,
        k=29,
        metric=metric,
        metric_kwads=metric_kwds,
        exclude_self=False,
        pynn_rand_state=seed,
        return_nbrs=True,
        n_jobs=cores,
        **kwargs,
    )

    # Median 1-st nearest neighbor distance over all observed cells: the "close enough" yardstick.
    median_dist = np.median(distances[:, 1])

    pred_dict = {}
    cell_predictions, cell_indx = (
        adata.uns[fate_key]["prediction"],
        adata.uns[fate_key]["init_cells"],
    )
    t = adata.uns[fate_key]["t"]
    confidence = np.zeros(len(t))

    for i, prediction in tqdm(enumerate(cell_predictions), desc="calculating fate distributions"):
        cur_t, n_steps = t[i], len(t[i])

        # Choose the time indices to evaluate fate bias at, by priority:
        # 1. explicit `inds`;
        # 2. `use_sink_percentage`: the tail of the path where speed drops below the
        #    `speed_percentile` percentile (at least the last 10 steps);
        # 3. `step_used_percentage`: the last `step_used_percentage` fraction of steps;
        # 4. otherwise all steps.
        if inds is not None:
            indices = inds
        elif use_sink_percentage:
            avg_speed = np.array([np.linalg.norm(i) for i in np.diff(prediction, 1).T]) / np.diff(cur_t)
            sink_checker = np.where(avg_speed[::-1] > np.percentile(avg_speed, speed_percentile))[0]
            indices = np.arange(n_steps - max(min(sink_checker), 10), n_steps)
        elif step_used_percentage is not None:
            # BUGFIX: this branch previously tested `step_used_percentage is float`, which compares
            # the value against the type object and is always False, so it could never run.
            indices = np.arange(int(n_steps - step_used_percentage * n_steps), n_steps)
        else:
            main_info("using all steps data")
            indices = np.arange(0, n_steps)

        # Find the 0-th..29-th nearest observed cells of the selected path points.
        if alg == "pynn":
            knn, distances = nbrs.query(prediction[:, indices].T, k=30)
        elif alg == "hnswlib":
            knn, distances = nbrs.knn_query(prediction[:, indices].T, k=30)
            if metric == "euclidean":
                # hnswlib returns squared L2 distances.
                distances = np.sqrt(distances)
            knn, distances = correct_hnsw_neighbors(knn, distances)
        else:
            distances, knn = nbrs.kneighbors(prediction[:, indices].T)

        # if final steps too far away from observed cells, walk the indices backward until any
        # visited path point lies close enough to the data.
        walk_back_steps = 0
        while True:
            is_close_to_data = distances.flatten() < dist_threshold * median_dist
            if any(is_close_to_data):
                # let us diffuse one step further to identify cells from terminal cell types in case
                # cells with indices are all close to some random progenitor cells.
                if hasattr(nbrs, "query"):
                    knn, _ = nbrs.query(X[knn.flatten(), :], k=30)
                elif hasattr(nbrs, "knn_query"):
                    knn, distances_hn = nbrs.knn_query(X[knn.flatten(), :], k=30)
                    knn, _ = correct_hnsw_neighbors(knn, distances_hn)
                else:
                    _, knn = nbrs.kneighbors(X[knn.flatten(), :])

                # Group frequencies among the identified neighbor cells define the fate probability.
                fate_prob = clusters[knn.flatten()].value_counts() / len(knn.flatten())
                if source_groups is not None:
                    # Reassign mass on progenitor groups to the most likely non-source group.
                    source_p = fate_prob[source_groups].sum()
                    if 1 > source_p > 0:
                        fate_prob[source_groups] = 0
                        fate_prob[fate_prob.idxmax()] += source_p

                pred_dict[i] = fate_prob

                confidence[i] = 1 - (sum(~is_close_to_data) + walk_back_steps) / (
                    len(is_close_to_data) + walk_back_steps
                )
                break
            else:
                walk_back_steps += 1

                if any(indices - 1 < 0):
                    # Walked past the start of the trajectory: fate is undeterminable.
                    pred_dict[i] = clusters[knn.flatten()].value_counts() * np.nan
                    break

                if hasattr(nbrs, "query"):
                    knn, distances = nbrs.query(prediction[:, indices - 1].T, k=30)
                elif hasattr(nbrs, "knn_query"):
                    knn, distances = nbrs.knn_query(prediction[:, indices - 1].T, k=30)
                    if metric == "euclidean":
                        distances = np.sqrt(distances)
                    knn, distances = correct_hnsw_neighbors(knn, distances)
                else:
                    distances, knn = nbrs.kneighbors(prediction[:, indices - 1].T)

                knn, distances = knn[:, 0], distances[:, 0]
                indices = indices - 1

    bias = pd.DataFrame(pred_dict).T
    conf = pd.DataFrame({"confidence": confidence}, index=bias.index)
    bias = pd.merge(conf, bias, left_index=True, right_index=True)

    if cell_indx is not None:
        bias.index = cell_indx

    return bias
# def fate_(adata, time, direction = 'forward'):
# from .moments import *
# gene_exprs = adata.X
# cell_num, gene_num = gene_exprs.shape
#
#
# for i in range(gene_num):
# params = {'a': adata.uns['dynamo'][i, "a"], \
# 'b': adata.uns['dynamo'][i, "b"], \
# 'la': adata.uns['dynamo'][i, "la"], \
# 'alpha_a': adata.uns['dynamo'][i, "alpha_a"], \
# 'alpha_i': adata.uns['dynamo'][i, "alpha_i"], \
# 'sigma': adata.uns['dynamo'][i, "sigma"], \
# 'beta': adata.uns['dynamo'][i, "beta"], \
# 'gamma': adata.uns['dynamo'][i, "gamma"]}
# mom = moments_simple(**params)
# for j in range(cell_num):
# x0 = gene_exprs[i, j]
# mom.set_initial_condition(*x0)
# if direction == "forward":
# gene_exprs[i, j] = mom.solve([0, time])
# elif direction == "backward":
# gene_exprs[i, j] = mom.solve([0, - time])
#
# adata.uns['prediction'] = gene_exprs
# return adata
def andecestor(
    adata: AnnData,
    init_cells: List,
    init_states: Optional[np.ndarray] = None,
    cores: int = 1,
    t_end: int = 50,
    basis: str = "umap",
    n_neighbors: int = 5,
    direction: str = "forward",
    interpolation_num: int = 250,
    last_point_only: bool = False,
    metric: str = "euclidean",
    metric_kwds: dict = None,
    seed: int = 19491001,
    **kwargs,
) -> None:
    """Label the ancestors or descendants of a group of initial cells (states).

    Each initial state is integrated through the vector field reconstructed on `basis`;
    the observed cells nearest to the predicted states are then flagged in a new
    boolean `.obs` column ("descendant", "ancestor" or "lineage", depending on
    `direction`).

    Args:
        adata: AnnData object that contains the reconstructed vector field function in the `uns` attribute.
        init_cells: Cell names or indices of the initial cell states. Names not found in `adata.obs_names`
            are treated as integer cell indices.
        init_states: Explicit initial cell states; when given they must have the same number of columns as
            `adata.obsm["X_" + basis]`. When None, the states are taken from `init_cells`.
        cores: Number of cores used to build the nearest neighbor graph.
        t_end: The length of the time period over which to integrate forward or backward.
        basis: The key suffix in `adata.obsm` ("X_" + basis) of the embedding used for prediction.
        n_neighbors: Number of nearest neighbors to query per predicted state.
        direction: One of `forward`, `backward` or `both`.
        interpolation_num: The number of uniformly interpolated time points.
        last_point_only: When True, only the trajectory endpoint(s) are used for the neighbor search
            (both endpoints when `direction == "both"`, otherwise the final point).
        metric: Distance metric for the neighbor search; see :class:`DistanceMetric` for the options.
        metric_kwds: Additional keyword arguments for the metric function.
        seed: Random seed to ensure the reproducibility of each run.
        kwargs: Additional arguments forwarded to the nearest neighbor search algorithm.

    Returns:
        Nothing; updates `adata.obs` with a boolean column of predicted ancestors or descendants.
    """
    logger = LoggerManager.gen_logger("dynamo-andecestor")
    logger.log_time()

    main_info("retrieve vector field function.")
    vec_dict, vecfld = vecfld_from_adata(adata, basis=basis)

    basis_key = "X_" + basis
    X = adata.obsm[basis_key].copy()

    main_info("build a kNN graph structure so we can query the nearest cells of the predicted states.")
    _, _, nbrs, alg = k_nearest_neighbors(
        X,
        k=n_neighbors - 1,
        metric=metric,
        metric_kwads=metric_kwds,
        exclude_self=False,
        pynn_rand_state=seed,
        n_jobs=cores,
        return_nbrs=True,
        logger=logger,
        **kwargs,
    )

    # Resolve the starting states: from the named cells, or validate the explicit array.
    if init_states is None:
        init_states = adata[init_cells, :].obsm[basis_key]
    elif init_states.shape[1] != adata.obsm[basis_key].shape[1]:
        raise Exception(
            f"init_states has to have the same columns as adata.obsm[{basis_key}] but you have "
            f"{init_states.shape[1]}"
        )

    main_info("predict cell state trajectory via integrating vector field function.")
    t, pred = _fate(
        vecfld,
        init_states,
        t_end=t_end,
        interpolation_num=interpolation_num,
        average=False,
        sampling="arc_length",
        cores=cores,
        direction=direction,
    )

    main_info("identify the progenitors/descendants by finding predicted cell states' nearest cells.")
    # When only endpoints are requested, a bidirectional integration has two of them.
    endpoint_rows = [0, -1] if direction == "both" else [-1]
    hit_indices = []
    for trajectory in pred:
        states = trajectory.T  # rows become individual predicted states
        queries = states[endpoint_rows] if last_point_only else states
        if alg == "pynn":
            neighbor_idx, neighbor_dist = nbrs.query(queries, k=n_neighbors)
        elif alg == "hnswlib":
            neighbor_idx, neighbor_dist = nbrs.knn_query(queries, k=n_neighbors)
            if metric == "euclidean":
                # hnswlib reports squared L2 distances.
                neighbor_dist = np.sqrt(neighbor_dist)
            neighbor_idx, neighbor_dist = correct_hnsw_neighbors(neighbor_idx, neighbor_dist)
        else:
            neighbor_dist, neighbor_idx = nbrs.kneighbors(queries)
        hit_indices += list(neighbor_idx.flatten())

    nearest_cell_inds = np.unique(hit_indices)

    if init_cells is not None:
        # Drop the initial cells themselves from the labeled set.
        if type(init_cells[0]) is int:
            init_cells = adata.obs_names[init_cells]
        nearest_cells = list(set(adata.obs_names[nearest_cell_inds]).difference(init_cells))
    else:
        nearest_cells = list(adata.obs_names[nearest_cell_inds])

    if direction == "forward":
        obs_key = "descendant"
    elif direction == "backward":
        obs_key = "ancestor"
    else:
        obs_key = "lineage"
    main_info_insert_adata(obs_key)
    adata.obs[obs_key] = False
    adata.obs.loc[nearest_cells, obs_key] = True
    logger.finish_progress(progress_name=f"predict {obs_key}")