Skip to content

Commit

Permalink
Merge pull request #564 from Sichao25/tool
Browse files Browse the repository at this point in the history
Implement and debug DDRTree based methods
  • Loading branch information
Xiaojieqiu authored Oct 30, 2023
2 parents e5e9017 + 35b9a6e commit 4a5ac26
Show file tree
Hide file tree
Showing 8 changed files with 818 additions and 67 deletions.
2 changes: 2 additions & 0 deletions dynamo/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
show_fraction,
variance_explained,
)
from .pseudotime import plot_dim_reduced_direct_graph
from .scatters import scatters
from .scPotential import show_landscape
from .sctransform import sctransform_plot_fit, plot_residual_var
Expand Down Expand Up @@ -153,4 +154,5 @@
"hessian",
"sctransform_plot_fit",
"plot_residual_var",
"plot_dim_reduced_direct_graph",
]
202 changes: 200 additions & 2 deletions dynamo/plot/pseudotime.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,213 @@
from typing import Any, Dict, Tuple
import math
from typing import Any, Dict, List, Optional, Tuple, Union

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix

from ..tools.utils import update_dict
from .utils import save_fig
from .utils import get_color_map_from_labels, save_fig


def _calculate_cells_mapping(
    adata: AnnData,
    group_key: str,
    cell_proj_closest_vertex: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    """Calculate the distribution of cells in each node.

    Args:
        adata: the anndata object.
        group_key: the key to locate the group of each cell in `adata.obs`.
        cell_proj_closest_vertex: the mapping from each cell to the index of its corresponding node. Must be
            provided even though the signature defaults to `None`.

    Returns:
        A tuple of three items: the relative size of each node (fraction of all cells mapped to it, floored at 0.05),
        the per-node fraction of every group as an array of shape (n_nodes, n_groups) whose columns follow the
        iteration order of the color mapping, and the color mapping of the groups.

    Raises:
        ValueError: if `cell_proj_closest_vertex` is `None`.
    """
    if cell_proj_closest_vertex is None:
        raise ValueError("`cell_proj_closest_vertex` is required to map cells to nodes.")

    labels = adata.obs[group_key].values
    cells_per_node = np.bincount(cell_proj_closest_vertex)
    node_index = range(len(cells_per_node))

    cell_color_map = get_color_map_from_labels(labels)

    group_counts = (
        pd.DataFrame({"class": labels, "centroid": cell_proj_closest_vertex})
        .groupby(["centroid", "class"])
        .size()
        .unstack()
    )
    # Callers pick colors by column position, so force the columns into the same order as the color mapping
    # (this also drops unused categories that a categorical `class` column may introduce).
    group_counts = group_counts.reindex(index=node_index, columns=list(cell_color_map), fill_value=0)

    # Nodes without any cell yield 0 / 0 = NaN, and (node, group) pairs never observed are NaN after `unstack`;
    # both are reported as 0.
    cells_mapping_percentage = np.nan_to_num(group_counts.div(cells_per_node, axis=0).values)

    # Floor the relative sizes so that sparsely populated nodes remain visible when drawn.
    cells_mapping_size = np.maximum(cells_per_node / len(cell_proj_closest_vertex), 0.05)

    return cells_mapping_size, cells_mapping_percentage, cell_color_map


def _scale_positions(positions: np.ndarray, variance_scale: int = 1.5) -> np.ndarray:
"""Scale an array representing to the matplotlib coordinates system and scale the variance if needed.
Args:
positions: the array representing the positions of the data to plot.
variance_scale: the value to scale the variance of data.
Returns:
The positions after scaling.
"""
min_value = np.min(positions)
max_value = np.max(positions)
pos = (positions - min_value) / (max_value - min_value)
mean = np.mean(pos, axis=0)
pos = (pos - mean) * variance_scale
return pos


def plot_dim_reduced_direct_graph(
    adata: AnnData,
    group_key: Optional[str] = "Cell_type",
    graph: Optional[Union[csr_matrix, np.ndarray]] = None,
    cell_proj_closest_vertex: Optional[np.ndarray] = None,
    center_coordinates: Optional[np.ndarray] = None,
    display_piechart: bool = True,
    variance_scale: float = 1.5,
    save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
    save_kwargs: Dict[str, Any] = {},
) -> Optional[Any]:
    """Plot the directed graph constructed from velocity-guided pseudotime.

    Args:
        adata: the anndata object.
        group_key: the key to locate the groups of each cell in `adata.obs`.
        graph: the adjacency matrix of the directed graph to plot, dense or sparse. If None,
            `adata.uns["directed_velocity_tree"]` is used.
        cell_proj_closest_vertex: the mapping from each cell to the corresponding node. If None,
            `adata.uns["cell_order"]["pr_graph_cell_proj_closest_vertex"]` is used.
        center_coordinates: the array representing the positions of the center nodes in the low dimensions, one row
            per node. If None, the transpose of `adata.uns["cell_order"]["Y"]` is used. The positions are needed
            to place the nodes whether or not pie charts are displayed.
        display_piechart: whether to draw each node as a pie chart of its group composition. Otherwise every node
            is drawn as a single marker colored by its most frequent group.
        variance_scale: the value to scale the variance of data. This is employed to space out the pie charts
            when they are positioned too closely to each other.
        save_show_or_return: whether to save, show or return the plot. "both" saves and shows; "all" saves, shows
            and returns.
        save_kwargs: additional keyword arguments of plot saving; merged into the saving defaults with
            `update_dict`. This dictionary is only read here.

    Returns:
        The value returned by `nx.draw_networkx_edges` (the drawn edge artists) when `save_show_or_return` is
        "return" or "all", otherwise `None`.

    Raises:
        KeyError: if the required cell order data cannot be found in `adata.uns`.
    """
    from scipy.sparse import issparse

    # Resolve every default from `adata.uns` up front, before any figure is created.
    try:
        if graph is None:
            graph = adata.uns["directed_velocity_tree"]

        if cell_proj_closest_vertex is None:
            cell_proj_closest_vertex = adata.uns["cell_order"]["pr_graph_cell_proj_closest_vertex"]

        if center_coordinates is None:
            center_coordinates = adata.uns["cell_order"]["Y"].T.copy()
    except KeyError as err:
        raise KeyError("Cell order data is missing. Please run `tl.order_cells()` first!") from err

    # `nx.from_numpy_array` expects a dense array, while the signature also admits sparse matrices.
    if issparse(graph):
        graph = graph.toarray()

    cells_size, cells_percentage, cells_color_map = _calculate_cells_mapping(
        adata=adata,
        group_key=group_key,
        cell_proj_closest_vertex=cell_proj_closest_vertex,
    )

    cells_colors = np.array(list(cells_color_map.values()))
    empty_color = [0, 0, 0, 1]  # nodes without any cell are drawn in black

    _, ax = plt.subplots(figsize=(6, 6))

    G = nx.from_numpy_array(graph, create_using=nx.DiGraph)

    # The mapping only covers nodes up to the largest index that has a cell assigned; pad the trailing empty
    # nodes so that every node of the graph can be looked up below.
    n_missing = G.number_of_nodes() - len(cells_size)
    if n_missing > 0:
        # 0.05 is the minimum node size that `_calculate_cells_mapping` assigns.
        cells_size = np.concatenate([cells_size, np.full(n_missing, 0.05)])
        cells_percentage = np.vstack([cells_percentage, np.zeros((n_missing, cells_percentage.shape[1]))])

    pos = _scale_positions(center_coordinates, variance_scale=variance_scale)
    pos_dict = dict(enumerate(pos))
    node_sizes = [s * len(cells_size) * 300 for s in cells_size]

    if display_piechart:
        for node in G.nodes:
            attributes = cells_percentage[node]
            valid_indices = np.flatnonzero(attributes)

            if valid_indices.size == 0:
                ax.pie([1], center=pos[node], colors=[empty_color], radius=cells_size[node])
            else:
                ax.pie(
                    attributes[valid_indices],
                    center=pos[node],
                    colors=cells_colors[valid_indices],
                    radius=cells_size[node],
                )
    else:
        dominate_colors = [
            empty_color if np.all(cells_percentage[node] == 0) else cells_colors[np.argmax(cells_percentage[node])]
            for node in G.nodes
        ]
        nx.draw_networkx_nodes(G, pos=pos_dict, node_color=dominate_colors, node_size=node_sizes, ax=ax)

    # `node_size` is passed so that the arrows are drawn with respect to the extent of each node.
    g = nx.draw_networkx_edges(
        G,
        pos=pos_dict,
        node_size=node_sizes,
        arrows=True,
        arrowstyle="->",
        arrowsize=20,
        ax=ax,
    )

    # Extra legend entry for the color used by empty nodes.
    cells_color_map["None"] = np.array(empty_color)
    legend_handles = [
        plt.Line2D([0], [0], marker="o", color="w", label=label, markerfacecolor=color)
        for label, color in cells_color_map.items()
    ]
    ax.legend(handles=legend_handles, loc="best", fontsize="medium")

    if save_show_or_return in ["save", "both", "all"]:
        s_kwargs = {
            "path": None,
            "prefix": "plot_dim_reduced_direct_graph",
            "dpi": None,
            "ext": "pdf",
            "transparent": True,
            "close": True,
            "verbose": True,
        }
        s_kwargs = update_dict(s_kwargs, save_kwargs)

        if save_show_or_return in ["both", "all"]:
            # The figure is still needed for showing (and returning) afterwards.
            s_kwargs["close"] = False

        save_fig(**s_kwargs)
    if save_show_or_return in ["show", "both", "all"]:
        plt.tight_layout()
        plt.show()
    if save_show_or_return in ["return", "all"]:
        return g


def plot_direct_graph(
Expand Down
2 changes: 2 additions & 0 deletions dynamo/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def is_list_of_lists(list_of_lists):

def get_color_map_from_labels(labels: np.ndarray, color_key_cmap: str = "glasbey_white") -> np.ndarray:
"""Generate a color map according to given labels.
Args:
labels: the label representing the groups of data.
color_key_cmap: the cmap used to generate the colors. Recommend 'glasbey_white'/'glasbey_black' for discrete
data, and 'inferno'/'viridis' for continuous data.
Returns:
The mapping of colors corresponding to each unique label.
"""
Expand Down
2 changes: 1 addition & 1 deletion dynamo/tools/DDRTree_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def DDRTree(
iterations.
"""

X = np.array(X)
X = np.array(X).T
(D, N) = X.shape

# initialization
Expand Down
3 changes: 3 additions & 0 deletions dynamo/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
)

# Pseudotime related
from .construct_velocity_tree import construct_velocity_tree, construct_velocity_tree_py
from .DDRTree_py import DDRTree, cal_ncenter
from .pseudotime import order_cells
from .time_series import directed_pg

# dimension reduction related
from .dimension_reduction import reduceDimension # , run_umap
Expand Down
Loading

0 comments on commit 4a5ac26

Please sign in to comment.