From 305b3c580df3350268e06f80cd0bba8d2511ed21 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 7 Dec 2023 18:58:39 -0500 Subject: [PATCH 1/9] create the basic scatter func --- dynamo/plot/scatters.py | 764 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 764 insertions(+) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index e00b520d2..4709c922b 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1363,3 +1363,767 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: save_show_or_return=save_show_or_return, save_kwargs=save_kwargs, ) + + +def scatters_single_input( + adata: AnnData, + basis: str = "umap", + x: int = 0, + y: int = 1, + z: int = 2, + color: str = "ntr", + layer: str = "X", + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + background: Optional[str] = None, + ncols: int = 4, + pointsize: Optional[float] = None, + figsize: Tuple[float, float] = (6, 4), + show_legend: str = "on data", + use_smoothed: bool = True, + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + ax: Optional[Axes] = None, + sort: Literal["raw", "abs", "neg"] = "raw", + save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", + save_kwargs: Dict[str, Any] = {}, + return_all: bool = False, + add_gamma_fit: bool = False, + frontier: bool = False, + contour: bool = False, + ccmap: Optional[str] = None, + alpha: float = 0.1, + calpha: float = 0.4, + sym_c: bool = False, + smooth: bool = False, + dpi: int = 100, + inset_dict: Dict[str, Any] = {}, + marker: Optional[str] = None, + group: Optional[str] = None, + add_group_gamma_fit: bool = False, + affine_transform_degree: Optional[int] = None, + affine_transform_A: Optional[float] = None, + affine_transform_b: Optional[float] = None, + stack_colors: bool = False, + stack_colors_threshold: float = 0.001, + stack_colors_title: str = "stacked colors", + stack_colors_legend_size: float = 2, + stack_colors_cmaps: Optional[List[str]] = None, + despline: bool = True, + deaxis: bool = True, + despline_sides: Optional[List[str]] = None, + projection: str = "2d", + **kwargs, +) -> Union[ + Axes, + List[Axes], + Tuple[Axes, List[str], Literal["white", "black"]], + Tuple[List[Axes], List[str], Literal["white", "black"]], + None, +]: + """Plot an embedding as points. Currently this only works for 2D embeddings. While there are many optional + parameters to further control and tailor the plotting, you need only pass in the trained/fit umap model to get + results. This plot utility will attempt to do the hard work of avoiding overplotting issues, and make it easy to + automatically color points by a categorical labelling or numeric values. This method is intended to be used within a + Jupyter notebook with `%matplotlib inline`. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + highlights: the color group that will be highlighted. If highligts is a list of lists, each list is relate to + each color element. Defaults to None. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue', + 'darkred', 'darkgreen'}. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + ncols: the number of columns for the figure. Defaults to 4. + pointsize: the scale of the point size. Actual point cell size is calculated as + `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. + figsize: the width and height of a figure. Defaults to (6, 4). + show_legend: whether to display a legend of the labels. Defaults to "on data". + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + aggregate: the column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: whether to show a pair of arrowed spines representing the basis of the scatter is currently + using. Defaults to False. + ax: the matplotlib axes object where new plots will be added to. Only applicable to drawing a single component. + Defaults to None. + sort: the method to reorder data so that high values points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary and + the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + return_all: whether to return all the scatter related variables. Defaults to False. + add_gamma_fit: whether to add the line of the gamma fitting. This will automatically turn on if `basis` points + to gene names and those genes have went through gamma fitting. Defaults to False. + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. If `contour` is set to be True, + `frontier` will be ignored as `contour` also add an outlier for data points. Defaults to False. + contour: whether to add an countor on top of scatter plots. We use tricontourf to plot contour for non-gridded + data. The shapely package was used to create a polygon of the concave hull of the scatters. With the polygon + we then check if the mean of the triangulated points are within the polygon and use this as our condition to + form the mask to create the contour. We also add the polygon shape as a frontier of the data point (similar + to when setting `frontier = True`). When the color of the data points is continuous, we will use the same + cmap as for the scatter points by default, when color is categorical, no contour will be drawn but just the + polygon. cmap can be set with `ccmap` argument. See below. This has recently changed to use seaborn's + kdeplot. Defaults to False. + ccmap: the name of a matplotlib colormap to use for coloring or shading points the contour. See above. + Defaults to None. + alpha: the point's alpha (transparency) value. Defaults to 0.1. + calpha: contour alpha value passed into sns.kdeplot. The value should be inbetween [0, 1]. Defaults to 0.4. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will based on `smooth` steps of + diffusion. + dpi: the resolution of the figure in dots-per-inch. Dots per inches (dpi) determines how many pixels the figure + comprises. dpi is different from ppi or points per inches. Note that most elements like lines, markers, + texts have a size given in points so you can convert the points to inches. Matplotlib figures use Points per + inch (ppi) of 72. A line with thickness 1 point will be 1./72. inch wide. A text with fontsize 12 points + will be 12./72. inch heigh. Of course if you change the figure size in inches, points will not change, so a + larger figure in inches still has the same size of the elements.Changing the figure size is thus like taking + a piece of paper of a different size. Doing so, would of course not change the width of the line drawn with + the same pen. On the other hand, changing the dpi scales those elements. At 72 dpi, a line of 1 point size + is one pixel strong. At 144 dpi, this line is 2 pixels strong. A larger dpi will therefore act like a + magnifying glass. All elements are scaled by the magnifying power of the lens. see more details at answer 2 + by @ImportanceOfBeingErnest: + https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size. Defaults to 100. + inset_dict: a dictionary of parameters in inset_ax. Example, something like {"width": "5%", "height": "50%", "loc": + 'lower left', "bbox_to_anchor": (0.85, 0.90, 0.145, 0.145), "bbox_transform": ax.transAxes, "borderpad": 0} + See more details at https://matplotlib.org/api/_as_gen/mpl_toolkits.axes_grid1.inset_locator.inset_axes.html + or https://stackoverflow.com/questions/39803385/what-does-a-4-element-tuple-argument-for-bbox-to-anchor-mean-in-matplotlib. + Defaults to {}. + marker: the marker style. marker can be either an instance of the class or the text shorthand for a particular + marker. See matplotlib.markers for more information about marker styles. Defaults to None. + group: the key in `adata.obs` corresponding to the cell group data. Defaults to None. + add_group_gamma_fit: whether to plot the cell group's gamma fit results. Defaults to False. + affine_transform_degree: transform coordinates of points according to some degree. Defaults to None. + affine_transform_A: coefficients in affine transformation Ax + b. 2D for now. Defaults to None. + affine_transform_b: bias in affine transformation Ax + b. Defaults to None. + stack_colors: whether to stack all color on the same ax passed above. Currently only support 18 sequential + matplotlib default cmaps assigning to different color groups. (#colors should be smaller than 18, reuse if + #colors > 18. TODO generate cmaps according to #colors). Defaults to False. + stack_colors_threshold: a threshold for filtering out points values < threshold when drawing each color. E.g. if + you do not want points with values < 1 showing up on axis, set threshold to be 1. Defaults to 0.001. + stack_colors_title: the title for the stack_color plot. Defaults to "stacked colors". + stack_colors_legend_size: the legend size in stack color plot. Defaults to 2. + stack_colors_cmaps: a list of cmaps that will be used to map values to color when stacking colors on the same + subplot. The order corresponds to the order of color. Defaults to None. + despline: whether to remove splines of the figure. Defaults to True. + deaxis: whether to remove axis ticks of the figure. Defaults to True. + despline_sides: which side of splines should be removed. Can be any combination of `["bottom", "right", "top", "left"]`. Defaults to None. + projection: the projection property of the matplotlib.Axes. Defaults to "2d". + **kwargs: any other kwargs that would be passed to `pyplot.scatters`. + + Raises: + ValueError: invalid adata object: lacking of required layers. + ValueError: `basis` not found in `adata.obsm`. + ValueError: invalid `x` or `y`. + ValueError: `labels` and `values` conflicted. + ValueError: invalid velocity estimation in `adata`. + ValueError: invalid velocity estimation in `adata`. + + Returns: + None would be returned by default. If `save_show_or_return` is set to be 'return' or 'all', the matplotlib axes + object of the generated plots would be returned. If `return_all` is set to be true, the list of colors used and + the font color would also be returned. + """ + + import matplotlib.pyplot as plt + + # 2d is not a projection in matplotlib, default is None (rectilinear) + if projection == "2d": + projection = None + if calpha < 0 or calpha > 1: + main_warning( + "calpha=%f is invalid (smaller than 0 or larger than 1) and may cause potential issues. Please check." + % (calpha) + ) + group_colors = ["b", "g", "r", "c", "m", "y", "k", "w"] + + if stack_colors and stack_colors_cmaps is None: + main_info("using default stack colors") + stack_colors_cmaps = [ + "Greys", + "Purples", + "Blues", + "Greens", + "Oranges", + "Reds", + "YlOrBr", + "YlOrRd", + "OrRd", + "PuRd", + "RdPu", + "BuPu", + "GnBu", + "PuBu", + "YlGnBu", + "PuBuGn", + "BuGn", + "YlGn", + ] + stack_legend_handles = [] + if stack_colors: + color_key = None + + if not (affine_transform_degree is None): + affine_transform_A = gen_rotation_2d(affine_transform_degree) + affine_transform_b = 0 + + if contour: + frontier = False + + if background is None: + _background = rcParams.get("figure.facecolor") + _background = to_hex(_background) if type(_background) is tuple else _background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + else: + _background = background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + + if is_gene_name(adata, basis): + if x not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y not in [ + "M_u", + "X_unspliced", + "M_n", + "X_new", + "unspliced", + "new", + ]: + if "M_t" in adata.layers.keys() and "M_n" in adata.layers.keys(): + x, y = "M_t", "M_n" + elif "X_total" in adata.layers.keys() and "X_new" in adata.layers.keys(): + x, y = "X_total", "X_new" + elif "M_s" in adata.layers.keys() and "M_u" in adata.layers.keys(): + x, y = "M_s", "M_u" + elif "X_spliced" in adata.layers.keys() and "X_unspliced" in adata.layers.keys(): + x, y = "X_spliced", "X_unspliced" + elif "spliced" in adata.layers.keys() and "unspliced" in adata.layers.keys(): + x, y = "spliced", "unspliced" + elif "total" in adata.layers.keys() and "new" in adata.layers.keys(): + x, y = "total", "new" + else: + raise ValueError( + "your adata object is corrupted. Please make sure it has at least one of the following " + "pair of layers:" + "'M_s', 'X_spliced', 'M_t', 'X_total', 'spliced', 'total' and " + "'M_u', 'X_unspliced', 'M_n', 'X_new', 'unspliced', 'new'. " + ) + + if use_smoothed: + mapper = get_mapper() + + if pointsize is None: + point_size = 16000.0 / np.sqrt(adata.shape[0]) + else: + point_size = 16000.0 / np.sqrt(adata.shape[0]) * pointsize + + scatter_kwargs = dict( + alpha=alpha, + s=point_size, + edgecolor=None, + linewidth=0, + rasterized=True, + marker=marker, + ) # (0, 0, 0, 1) + if kwargs is not None: + scatter_kwargs.update(kwargs) + + font_color = _select_font_color(_background) + + if figsize is None: + figsize = plt.rcParams["figsize"] + + figure = None # possible as argument in future + + # if #total_panel is 1, `_matplotlib_points` will create a figure. No need to create a figure here and generate a blank figure. + if ax is None: + fig, ax = plt.subplots() + + color_out = None + + if layer in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: + cur_l_smoothed = layer + cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future + else: + if use_smoothed: + cur_l_smoothed = layer if layer.startswith("M_") | layer.startswith("velocity") else mapper[layer] + if layer.startswith("velocity"): + cmap, sym_c = "bwr", True + + if layer + "_" + basis in adata.obsm.keys(): + prefix = layer + "_" + elif ("X_" + basis) in adata.obsm.keys(): + prefix = "X_" + elif basis in adata.obsm.keys(): + # special case for spatial for compatibility with other packages + prefix = "" + else: + raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) + + basis_key = prefix + basis + main_info("plotting with basis key=%s" % basis_key, indent_level=2) + + if stack_colors: + _stack_background_adata_indices = np.ones(len(adata), dtype=bool) + + main_debug("coloring scatter of cur_c: %s" % str(color)) + if not stack_colors: + cur_title = color + else: + cur_title = stack_colors_title + _color = _get_adata_color_vec(adata, layer, color) + + is_numeric_color = np.issubdtype(_color.dtype, np.number) + if not is_numeric_color: + main_info( + "skip filtering %s by stack threshold when stacking color because it is not a numeric type" + % (color), + indent_level=2, + ) + _values = values + + points, cur_title = _map_to_points( + adata, + axis_x=x, + axis_y=y, + axis_z=z, + basis_key=basis_key, + cur_c=color, + cur_b=basis, + cur_l_smoothed=cur_l_smoothed, + projection=projection, + ) + + if aggregate is not None: + groups, uniq_grp = ( + adata.obs[aggregate], + list(adata.obs[aggregate].unique()), + ) + group_color, group_median = ( + np.zeros((1, len(uniq_grp))).flatten() + if isinstance(_color[0], Number) + else np.zeros((1, len(uniq_grp))).astype("str").flatten(), + np.zeros((len(uniq_grp), 2)), + ) + + grp_size = adata.obs[aggregate].value_counts()[uniq_grp].values + scatter_kwargs = ( + {"s": grp_size} if scatter_kwargs is None else update_dict(scatter_kwargs, {"s": grp_size}) + ) + + for ind, cur_grp in enumerate(uniq_grp): + group_median[ind, :] = np.nanmedian( + points.iloc[np.where(groups == cur_grp)[0], :2], + 0, + ) + if isinstance(_color[0], Number): + group_color[ind] = np.nanmedian(np.array(_color)[np.where(groups == cur_grp)[0]]) + else: + group_color[ind] = pd.Series(_color)[np.where(groups == cur_grp)[0]].value_counts().index[0] + + points, _color = ( + pd.DataFrame( + group_median, + index=uniq_grp, + columns=points.columns, + ), + group_color, + ) + + is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" + + if is_not_continuous: + labels = np.asarray(_color) if is_categorical_dtype(_color) else _color + if theme is None: + if _background in ["#ffffff", "black"]: + _theme_ = "glasbey_dark" + else: + _theme_ = "glasbey_white" + else: + _theme_ = theme + else: + _values = _color + if theme is None: + if _background in ["#ffffff", "black"]: + _theme_ = "inferno" if layer != "velocity" else "div_blue_black_red" + else: + _theme_ = "viridis" if not layer.startswith("velocity") else "div_blue_red" + else: + _theme_ = theme + + _cmap = _themes[_theme_]["cmap"] if cmap is None else cmap + + _color_key_cmap = _themes[_theme_]["color_key_cmap"] if color_key_cmap is None else color_key_cmap + _background = _themes[_theme_]["background"] if _background is None else _background + + if labels is not None and values is not None: + raise ValueError("Conflicting options; only one of labels or values should be set") + + if highlights is not None: + _highlights = highlights if all([i in _color for i in highlights]) else None + + if smooth and not is_not_continuous: + main_debug("smooth and not continuous") + knn = adata.obsp["moments_con"] + values = ( + calc_1nd_moment(values, knn)[0] + if smooth in [1, True] + else calc_1nd_moment(values, knn ** smooth)[0] + ) + + if affine_transform_A is None or affine_transform_b is None: + point_coords = points.values + else: + point_coords = affine_transform(points.values, affine_transform_A, affine_transform_b) + + if points.shape[0] <= figsize[0] * figsize[1] * 100000: + main_debug("drawing with _matplotlib_points function") + ax, color_out = _matplotlib_points( + # points.values, + point_coords, + ax, + labels, + _values, + highlights, + _cmap, + color_key, + _color_key_cmap, + _background, + figsize[0], + figsize[1], + show_legend, + sort=sort, + frontier=frontier, + contour=contour, + ccmap=ccmap, + calpha=calpha, + sym_c=sym_c, + inset_dict=inset_dict, + projection=projection, + **scatter_kwargs, + ) + if labels is not None: + color_dict = {} + colors = [rgb2hex(i) for i in color_out] + for i, j in zip(labels, colors): + color_dict[i] = j + + adata.uns[cur_title + "_colors"] = color_dict + else: + main_debug("drawing with _datashade_points function") + ax = _datashade_points( + # points.values, + point_coords, + ax, + labels, + values, + highlights, + _cmap, + color_key, + _color_key_cmap, + _background, + figsize[0], + figsize[1], + show_legend, + sort=sort, + frontier=frontier, + contour=contour, + ccmap=ccmap, + calpha=calpha, + sym_c=sym_c, + **scatter_kwargs, + ) + + if show_arrowed_spines: + arrowed_spines(ax, points.columns[:2], _background) + else: + if despline: + despline_all(ax, despline_sides) + if deaxis: + deaxis_all(ax) + + ax.set_title(cur_title) + + if add_gamma_fit and basis in adata.var_names[adata.var.use_for_dynamics]: + xnew = np.linspace( + points.iloc[:, 0].min(), + points.iloc[:, 0].max() * 0.80, + ) + k_name = "gamma_k" if adata.uns["dynamics"]["experiment_type"] == "one-shot" else "gamma" + vel_params_df = get_vel_params(adata) + if k_name in vel_params_df.columns: + if not ("gamma_b" in vel_params_df.columns) or all(vel_params_df.gamma_b.isna()): + vel_params_df.loc[:, "gamma_b"] = 0 + update_vel_params(adata, params_df=vel_params_df) + ax.plot( + xnew, + xnew * adata[:, basis].var.loc[:, k_name].unique() + + adata[:, basis].var.loc[:, "gamma_b"].unique(), + dashes=[6, 2], + c=font_color, + ) + else: + raise ValueError( + "_adata does not seem to have %s column. Velocity estimation is required " + "before running this function." % k_name + ) + if group is not None and add_group_gamma_fit and basis in adata.var_names[adata.var.use_for_dynamics]: + cell_groups = adata.obs[group] + unique_groups = np.unique(cell_groups) + k_suffix = "gamma_k" if adata.uns["dynamics"]["experiment_type"] == "one-shot" else "gamma" + for group_idx, cur_group in enumerate(unique_groups): + group_k_name = group + "_" + cur_group + "_" + k_suffix + group_adata = adata[adata.obs[group] == cur_group] + group_points = points.iloc[np.array(adata.obs[group] == cur_group)] + group_b_key = group + "_" + cur_group + "_" + "gamma_b" + group_xnew = np.linspace( + group_points.iloc[:, 0].min(), + group_points.iloc[:, 0].max() * 0.90, + ) + group_ynew = ( + group_xnew * group_adata[:, basis].var.loc[:, group_k_name].unique() + + group_adata[:, basis].var.loc[:, group_b_key].unique() + ) + ax.annotate(group + "_" + cur_group, xy=(group_xnew[-1], group_ynew[-1])) + vel_params_df = get_vel_params(group_adata) + if group_k_name in vel_params_df.columns: + if not (group_b_key in vel_params_df.columns) or all(vel_params_df[group_b_key].isna()): + vel_params_df.loc[:, group_b_key] = 0 + update_vel_params(group_adata, params_df=vel_params_df) + main_info("No %s found, setting all bias terms to zero" % group_b_key) + ax.plot( + group_xnew, + group_ynew, + dashes=[6, 2], + c=group_colors[group_idx % len(group_colors)], + ) + else: + raise Exception( + "_adata does not seem to have %s column. Velocity estimation is required " + "before running this function." % group_k_name + ) + return ax + + +def _standardize_dimensions(x, y, z, n_obs, projection): + if projection == "3d": + x, y, z = ( + _standardize_dimension(x, n_obs), + _standardize_dimension(y, n_obs), + _standardize_dimension(z, n_obs), + ) + else: + x, y = _standardize_dimension(x, n_obs), _standardize_dimension(y, n_obs) + z = [np.nan] * len(x) + + assert len(x) == len(y) and len(x) == len(z), "bug: x, y, z does not have the same shape." + + +def _standardize_dimension(dim, n_obs): + if type(dim) in [anndata._core.views.ArrayView, np.ndarray] and len(dim) == n_obs: + return [dim] + elif hasattr(dim, "__len__"): + return list(dim) + + +def _map_to_points( + _adata: AnnData, + axis_x: str, + axis_y: str, + axis_z: str, + basis_key: str, + cur_c: str, + cur_b: str, + cur_l_smoothed: str, + projection: str, +) -> Tuple[pd.DataFrame, str]: + """A helper function to map the given axis to corresponding coordinates in current embedding space. + + Args: + _adata: an AnnData object. + axis_x: the column index of the low dimensional embedding for the x-axis in current space. + axis_y: the column index of the low dimensional embedding for the y-axis in current space. + axis_z: the column index of the low dimensional embedding for the z-axis in current space. + basis_key: the basis key constructed by current basis and layer. + cur_c: the current key to color the data. + cur_b: the current basis key representing the reduced dimension. + cur_l_smoothed: the smoothed layer of data to use. + + Returns: + The 3D DataFrame with coordinates of each sample and the title of the plot. + """ + gene_title = [] + anno_title = [] + + def _map_cur_axis_to_title( + cur, + _adata: AnnData, + cur_b: str, + cur_l_smoothed: str, + ) -> Tuple[np.ndarray, str]: + """A helper function to map an axis. + + Args: + cur: the current axis to map. + + Returns: + The coordinates and the column names. + """ + nonlocal gene_title, anno_title + + if is_gene_name(_adata, cur): + points_df_data = (_adata.obs_vector(k=cur, layer=None) + if cur_l_smoothed == "X" + else _adata.obs_vector(k=cur, layer=cur_l_smoothed)) + points_column = cur + " (" + cur_l_smoothed + ")" + gene_title.append(cur) + elif is_cell_anno_column(_adata, cur): + points_df_data = _adata.obs_vector(cur) + points_column = cur + anno_title.append(cur) + elif is_layer_keys(_adata, cur): + points_df_data = _adata[:, cur_b].layers[cur] + points_column = flatten(points_df_data) + else: + raise ValueError("Make sure your `x`, `y` and `z` are integers, gene names, column names in .obs, etc.") + + return points_df_data, points_column + + if projection == "3d": + if type(axis_x) is int and type(axis_y) is int and type(axis_z): + x_col_name = cur_b + "_0" + y_col_name = cur_b + "_1" + z_col_name = cur_b + "_2" + + points = pd.DataFrame( + { + x_col_name: _adata.obsm[basis_key][:, axis_x], + y_col_name: _adata.obsm[basis_key][:, axis_y], + z_col_name: _adata.obsm[basis_key][:, axis_z], + } + ) + points.columns = [x_col_name, y_col_name, z_col_name] + + cur_title = cur_c + + return points, cur_title + elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ + anndata._core.views.ArrayView, + np.ndarray, + ]: + points = pd.DataFrame({"x": flatten(axis_x), "y": flatten(axis_y), "z": flatten(axis_z)}) + points.columns = ["x", "y", "z"] + else: + x_points_df_data, x_points_column = _map_cur_axis_to_title(axis_x, _adata, cur_b, cur_l_smoothed) + y_points_df_data, y_points_column = _map_cur_axis_to_title(axis_y, _adata, cur_b, cur_l_smoothed) + z_points_df_data, z_points_column = _map_cur_axis_to_title(axis_z, _adata, cur_b, cur_l_smoothed) + points = pd.DataFrame({ + axis_x: x_points_df_data, + axis_y: y_points_df_data, + axis_z: z_points_df_data, + }) + points.columns = [x_points_column, y_points_column, z_points_column] + + if len(gene_title) != 0: + cur_title = " VS ".join(gene_title) + elif len(anno_title) == 3: + cur_title = " VS ".join(anno_title) + else: + cur_title = cur_b + else: + if type(axis_x) is int and type(axis_y) is int: + x_col_name = cur_b + "_0" + y_col_name = cur_b + "_1" + + points = pd.DataFrame( + { + x_col_name: _adata.obsm[basis_key][:, axis_x], + y_col_name: _adata.obsm[basis_key][:, axis_y], + } + ) + points.columns = [x_col_name, y_col_name] + + cur_title = cur_c + + return points, cur_title + elif type(axis_x) in [anndata._core.views.ArrayView, np.ndarray] and type(axis_y) in [ + anndata._core.views.ArrayView, + np.ndarray, + ]: + points = pd.DataFrame({"x": flatten(axis_x), "y": flatten(axis_y)}) + points.columns = ["x", "y"] + else: + x_points_df_data, x_points_column = _map_cur_axis_to_title(axis_x, _adata, cur_b, cur_l_smoothed) + y_points_df_data, y_points_column = _map_cur_axis_to_title(axis_y, _adata, cur_b, cur_l_smoothed) + points = pd.DataFrame({ + axis_x: x_points_df_data, + axis_y: y_points_df_data, + }) + points.columns = [x_points_column, y_points_column] + + if len(gene_title) != 0: + cur_title = " VS ".join(gene_title) + elif len(anno_title) == 2: + cur_title = " VS ".join(anno_title) + else: + cur_title = cur_b + + return points, cur_title From 424fee5cfb31c727713c5c1bc1cd6f44c7dcdbd2 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 8 Dec 2023 14:38:05 -0500 Subject: [PATCH 2/9] create _get_color_parameters for scatter plot --- dynamo/plot/scatters.py | 200 +++++++++++++++++++++------------------- 1 file changed, 107 insertions(+), 93 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 4709c922b..edfe6483a 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1595,33 +1595,6 @@ def scatters_single_input( "calpha=%f is invalid (smaller than 0 or larger than 1) and may cause potential issues. Please check." % (calpha) ) - group_colors = ["b", "g", "r", "c", "m", "y", "k", "w"] - - if stack_colors and stack_colors_cmaps is None: - main_info("using default stack colors") - stack_colors_cmaps = [ - "Greys", - "Purples", - "Blues", - "Greens", - "Oranges", - "Reds", - "YlOrBr", - "YlOrRd", - "OrRd", - "PuRd", - "RdPu", - "BuPu", - "GnBu", - "PuBu", - "YlGnBu", - "PuBuGn", - "BuGn", - "YlGn", - ] - stack_legend_handles = [] - if stack_colors: - color_key = None if not (affine_transform_degree is None): affine_transform_A = gen_rotation_2d(affine_transform_degree) @@ -1630,14 +1603,6 @@ def scatters_single_input( if contour: frontier = False - if background is None: - _background = rcParams.get("figure.facecolor") - _background = to_hex(_background) if type(_background) is tuple else _background - # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) - else: - _background = background - # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) - if is_gene_name(adata, basis): if x not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y not in [ "M_u", @@ -1686,16 +1651,12 @@ def scatters_single_input( if kwargs is not None: scatter_kwargs.update(kwargs) - font_color = _select_font_color(_background) - if figsize is None: figsize = plt.rcParams["figsize"] - figure = None # possible as argument in future - # if #total_panel is 1, `_matplotlib_points` will create a figure. No need to create a figure here and generate a blank figure. if ax is None: - fig, ax = plt.subplots() + figure, ax = plt.subplots() color_out = None @@ -1719,26 +1680,9 @@ def scatters_single_input( raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) basis_key = prefix + basis - main_info("plotting with basis key=%s" % basis_key, indent_level=2) - - if stack_colors: - _stack_background_adata_indices = np.ones(len(adata), dtype=bool) + main_info("plotting with basis key=%s" % basis_key, indent_level=2) main_debug("coloring scatter of cur_c: %s" % str(color)) - if not stack_colors: - cur_title = color - else: - cur_title = stack_colors_title - _color = _get_adata_color_vec(adata, layer, color) - - is_numeric_color = np.issubdtype(_color.dtype, np.number) - if not is_numeric_color: - main_info( - "skip filtering %s by stack threshold when stacking color because it is not a numeric type" - % (color), - indent_level=2, - ) - _values = values points, cur_title = _map_to_points( adata, @@ -1752,6 +1696,30 @@ def scatters_single_input( projection=projection, ) + ( + _color, + labels, + values, + _theme_, + cmap, + color_key_cmap, + background, + font_color, + is_not_continuous, + cur_title, + ) = _get_color_parameters( + adata=adata, + color=color, + layer=layer, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key_cmap=color_key_cmap, + background=background, + title=cur_title, + ) + if aggregate is not None: groups, uniq_grp = ( adata.obs[aggregate], @@ -1788,34 +1756,10 @@ def scatters_single_input( group_color, ) - is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" - - if is_not_continuous: - labels = np.asarray(_color) if is_categorical_dtype(_color) else _color - if theme is None: - if _background in ["#ffffff", "black"]: - _theme_ = "glasbey_dark" - else: - _theme_ = "glasbey_white" + if is_not_continuous: + labels = _color else: - _theme_ = theme - else: - _values = _color - if theme is None: - if _background in ["#ffffff", "black"]: - _theme_ = "inferno" if layer != "velocity" else "div_blue_black_red" - else: - _theme_ = "viridis" if not layer.startswith("velocity") else "div_blue_red" - else: - _theme_ = theme - - _cmap = _themes[_theme_]["cmap"] if cmap is None else cmap - - _color_key_cmap = _themes[_theme_]["color_key_cmap"] if color_key_cmap is None else color_key_cmap - _background = _themes[_theme_]["background"] if _background is None else _background - - if labels is not None and values is not None: - raise ValueError("Conflicting options; only one of labels or values should be set") + values = _color if highlights is not None: _highlights = highlights if all([i in _color for i in highlights]) else None @@ -1841,12 +1785,12 @@ def scatters_single_input( point_coords, ax, labels, - _values, + values, highlights, - _cmap, + cmap, color_key, - _color_key_cmap, - _background, + color_key_cmap, + background, figsize[0], figsize[1], show_legend, @@ -1876,10 +1820,10 @@ def scatters_single_input( labels, values, highlights, - _cmap, + cmap, color_key, - _color_key_cmap, - _background, + color_key_cmap, + background, figsize[0], figsize[1], show_legend, @@ -1893,7 +1837,7 @@ def scatters_single_input( ) if show_arrowed_spines: - arrowed_spines(ax, points.columns[:2], _background) + arrowed_spines(ax, points.columns[:2], background) else: if despline: despline_all(ax, despline_sides) @@ -1926,6 +1870,7 @@ def scatters_single_input( "before running this function." % k_name ) if group is not None and add_group_gamma_fit and basis in adata.var_names[adata.var.use_for_dynamics]: + group_colors = ["b", "g", "r", "c", "m", "y", "k", "w"] cell_groups = adata.obs[group] unique_groups = np.unique(cell_groups) k_suffix = "gamma_k" if adata.uns["dynamics"]["experiment_type"] == "one-shot" else "gamma" @@ -1963,6 +1908,75 @@ def scatters_single_input( return ax +def _get_color_parameters(adata, color, layer, labels, values, theme, cmap, color_key_cmap, background, title): + if background is None: + background = rcParams.get("figure.facecolor") + background = to_hex(background) if type(background) is tuple else background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + + font_color = _select_font_color(background) + + def _get_theme_for_labels(): + if theme is None: + if background in ["#ffffff", "black"]: + _theme_ = "glasbey_dark" + else: + _theme_ = "glasbey_white" + else: + _theme_ = theme + + return _theme_ + + def _get_theme_for_values(): + if theme is None: + if background in ["#ffffff", "black"]: + _theme_ = "inferno" if layer != "velocity" else "div_blue_black_red" + else: + _theme_ = "viridis" if not layer.startswith("velocity") else "div_blue_red" + else: + _theme_ = theme + + return _theme_ + + if labels is not None and values is not None: + raise ValueError("Conflicting options; only one of labels or values should be set") + + if labels is not None: + main_info("labels input will disable the color by %s" % color) + _color = labels + title = "customized labels" if title == color else title + elif values is not None: + main_info("values input will disable the color by %s" % color) + _color = values + title = "customized values" if title == color else title + else: + _color = _get_adata_color_vec(adata, layer, color) + + is_numeric_color = np.issubdtype(_color.dtype, np.number) + if not is_numeric_color: + main_info( + "skip filtering %s by stack threshold when stacking color because it is not a numeric type" + % color, + indent_level=2, + ) + + is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" + + if is_not_continuous: + labels = np.asarray(_color) if is_categorical_dtype(_color) else _color + _theme_ = _get_theme_for_labels() + else: + values = _color + _theme_ = _get_theme_for_values() + + cmap = _themes[_theme_]["cmap"] if cmap is None else cmap + + color_key_cmap = _themes[_theme_]["color_key_cmap"] if color_key_cmap is None else color_key_cmap + background = _themes[_theme_]["background"] if background is None else background + + return _color, labels, values, _theme_, cmap, color_key_cmap, background, font_color, is_not_continuous, title + + def _standardize_dimensions(x, y, z, n_obs, projection): if projection == "3d": x, y, z = ( From f900ea30fb73c994c3c49807563e7ea62c16acc7 Mon Sep 17 00:00:00 2001 From: sichao Date: Fri, 8 Dec 2023 18:46:12 -0500 Subject: [PATCH 3/9] create new scatters --- dynamo/plot/scatters.py | 328 ++++++++++++++++++++++++++++++++++------ 1 file changed, 285 insertions(+), 43 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index edfe6483a..6c9d76b58 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -18,6 +18,7 @@ from matplotlib.lines import Line2D from matplotlib.colors import rgb2hex, to_hex from pandas.api.types import is_categorical_dtype +from scipy.sparse import issparse from ..configuration import _themes, reset_rcParams from ..docrep import DocstringProcessor @@ -1365,6 +1366,280 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: ) +def scatters2( + adata: AnnData, + basis: str = "umap", + x: int = 0, + y: int = 1, + z: int = 2, + color: str = "ntr", + layer: str = "X", + highlights: Optional[list] = None, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[ + Literal[ + "blue", + "red", + "green", + "inferno", + "fire", + "viridis", + "darkblue", + "darkred", + "darkgreen", + ] + ] = None, + cmap: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + background: Optional[str] = None, + ncols: int = 4, + pointsize: Optional[float] = None, + figsize: Tuple[float, float] = (6, 4), + show_legend: str = "on data", + use_smoothed: bool = True, + aggregate: Optional[str] = None, + show_arrowed_spines: bool = False, + ax: Optional[Axes] = None, + sort: Literal["raw", "abs", "neg"] = "raw", + save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", + save_kwargs: Dict[str, Any] = {}, + return_all: bool = False, + add_gamma_fit: bool = False, + frontier: bool = False, + contour: bool = False, + ccmap: Optional[str] = None, + alpha: float = 0.1, + calpha: float = 0.4, + sym_c: bool = False, + smooth: bool = False, + dpi: int = 100, + inset_dict: Dict[str, Any] = {}, + marker: Optional[str] = None, + group: Optional[str] = None, + add_group_gamma_fit: bool = False, + affine_transform_degree: Optional[int] = None, + affine_transform_A: Optional[float] = None, + affine_transform_b: Optional[float] = None, + stack_colors: bool = False, + stack_colors_threshold: float = 0.001, + stack_colors_title: str = "stacked colors", + stack_colors_legend_size: float = 2, + stack_colors_cmaps: Optional[List[str]] = None, + despline: bool = True, + deaxis: bool = True, + despline_sides: Optional[List[str]] = None, + projection: str = "2d", + **kwargs, +) -> Union[ + Axes, + List[Axes], + Tuple[Axes, List[str], Literal["white", "black"]], + Tuple[List[Axes], List[str], Literal["white", "black"]], + None, +]: + import matplotlib.pyplot as plt + + if projection == "2d": + projection = None + + if calpha < 0 or calpha > 1: + main_warning( + "calpha=%f is invalid (smaller than 0 or larger than 1) and may cause potential issues. Please check." + % calpha + ) + + if background is None: + background = rcParams.get("figure.facecolor") + background = to_hex(background) if type(background) is tuple else background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + + if not (affine_transform_degree is None): + affine_transform_A = gen_rotation_2d(affine_transform_degree) + affine_transform_b = 0 + + if contour: + frontier = False + + if pointsize is None: + point_size = 16000.0 / np.sqrt(adata.shape[0]) + else: + point_size = 16000.0 / np.sqrt(adata.shape[0]) * pointsize + + scatter_kwargs = dict( + alpha=alpha, + s=point_size, + edgecolor=None, + linewidth=0, + rasterized=True, + marker=marker, + ) # (0, 0, 0, 1) + if kwargs is not None: + scatter_kwargs.update(kwargs) + + if figsize is None: + figsize = plt.rcParams["figsize"] + + if type(x) in [int, str]: + x = [x] + if type(y) in [int, str]: + y = [y] + if type(z) in [int, str]: + z = [z] + + if type(color) is str: + color = [color] + if type(layer) is str: + layer = [layer] + if type(basis) is str: + basis = [basis] + + n_c, n_l, n_b, n_x, n_y = ( + 1 if color is None else len(color), + 1 if layer is None else len(layer), + 1 if basis is None else len(basis), + 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), + # check whether it is an array + 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), + # check whether it is an array + ) + + total_panels, ncols = ( + n_c * n_l * n_b * n_x * n_y, + min(max([n_c, n_l, n_b, n_x, n_y]), ncols), + ) + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + + if total_panels > 1 and ax is None: + figure = plt.figure( + None, + (figsize[0] * ncol, figsize[1] * nrow), + facecolor=background, + dpi=dpi, + ) + gs = plt.GridSpec(nrow, ncol, wspace=0.12) + + ax_index = 0 + axes_list, color_list = [], [] + color_out = None + + for cur_b in basis: + for cur_l in layer: + main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) + main_debug("colors: %s" % (str(color))) + + for cur_c in color: + x, y, z = _standardize_dimensions(x, y, z, adata.n_obs, projection) + + for cur_x, cur_y, cur_z in zip(x, y, z): # here x / y are arrays + + if total_panels > 1 and not stack_colors: + ax = plt.subplot(gs[ax_index], projection=projection) + ax_index += 1 + + ax, color_out, font_color = scatters_single_input( + adata=adata, + basis=cur_b, + x=cur_x, + y=cur_y, + z=cur_z, + color=cur_c, + layer=cur_l, + highlights=highlights, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key=color_key, + color_key_cmap=color_key_cmap, + background=background, + ncols=ncols, + pointsize=pointsize, + figsize=figsize, + show_legend=show_legend, + use_smoothed=use_smoothed, + aggregate=aggregate, + show_arrowed_spines=show_arrowed_spines, + ax=ax, + sort=sort, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + return_all=return_all, + add_gamma_fit=add_gamma_fit, + frontier=frontier, + contour=contour, + ccmap=ccmap, + alpha=alpha, + calpha=calpha, + sym_c=sym_c, + smooth=smooth, + dpi=dpi, + inset_dict=inset_dict, + marker=marker, + group=group, + add_group_gamma_fit=add_group_gamma_fit, + affine_transform_degree=affine_transform_degree, + affine_transform_A=affine_transform_A, + affine_transform_b=affine_transform_b, + stack_colors=stack_colors, + stack_colors_threshold=stack_colors_threshold, + stack_colors_title=stack_colors_title, + stack_colors_legend_size=stack_colors_legend_size, + stack_colors_cmaps=stack_colors_cmaps, + despline=despline, + deaxis=deaxis, + despline_sides=despline_sides, + projection=projection, + scatter_kwargs=scatter_kwargs, + ** kwargs, + ) + + axes_list.append(ax) + color_list.append(color_out) + + main_debug("show, return or save...") + if save_show_or_return in ["save", "both", "all"]: + s_kwargs = { + "path": None, + "prefix": "scatters", + "dpi": None, + "ext": "pdf", + "transparent": True, + "close": True, + "verbose": True, + } + + # prevent the plot from being closed if the plot need to be shown or returned. + if save_show_or_return in ["both", "all"]: + s_kwargs["close"] = False + + s_kwargs = update_dict(s_kwargs, save_kwargs) + + save_fig(**s_kwargs) + if background is not None: + reset_rcParams() + if save_show_or_return in ["show", "both", "all"]: + if show_legend: + plt.subplots_adjust(right=0.85) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + plt.tight_layout() + + plt.show() + if background is not None: + reset_rcParams() + if save_show_or_return in ["return", "all"]: + if background is not None: + reset_rcParams() + + if return_all: + return (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) + else: + return axes_list if total_panels > 1 else ax + + def scatters_single_input( adata: AnnData, basis: str = "umap", @@ -1430,6 +1705,7 @@ def scatters_single_input( deaxis: bool = True, despline_sides: Optional[List[str]] = None, projection: str = "2d", + scatter_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> Union[ Axes, @@ -1584,25 +1860,8 @@ def scatters_single_input( object of the generated plots would be returned. If `return_all` is set to be true, the list of colors used and the font color would also be returned. """ - import matplotlib.pyplot as plt - # 2d is not a projection in matplotlib, default is None (rectilinear) - if projection == "2d": - projection = None - if calpha < 0 or calpha > 1: - main_warning( - "calpha=%f is invalid (smaller than 0 or larger than 1) and may cause potential issues. Please check." - % (calpha) - ) - - if not (affine_transform_degree is None): - affine_transform_A = gen_rotation_2d(affine_transform_degree) - affine_transform_b = 0 - - if contour: - frontier = False - if is_gene_name(adata, basis): if x not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y not in [ "M_u", @@ -1635,25 +1894,6 @@ def scatters_single_input( if use_smoothed: mapper = get_mapper() - if pointsize is None: - point_size = 16000.0 / np.sqrt(adata.shape[0]) - else: - point_size = 16000.0 / np.sqrt(adata.shape[0]) * pointsize - - scatter_kwargs = dict( - alpha=alpha, - s=point_size, - edgecolor=None, - linewidth=0, - rasterized=True, - marker=marker, - ) # (0, 0, 0, 1) - if kwargs is not None: - scatter_kwargs.update(kwargs) - - if figsize is None: - figsize = plt.rcParams["figsize"] - # if #total_panel is 1, `_matplotlib_points` will create a figure. No need to create a figure here and generate a blank figure. if ax is None: figure, ax = plt.subplots() @@ -1676,6 +1916,8 @@ def scatters_single_input( elif basis in adata.obsm.keys(): # special case for spatial for compatibility with other packages prefix = "" + elif is_gene_name(adata, basis): + prefix = "" else: raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) @@ -1905,14 +2147,10 @@ def scatters_single_input( "_adata does not seem to have %s column. Velocity estimation is required " "before running this function." % group_k_name ) - return ax + return ax, color_out, font_color def _get_color_parameters(adata, color, layer, labels, values, theme, cmap, color_key_cmap, background, title): - if background is None: - background = rcParams.get("figure.facecolor") - background = to_hex(background) if type(background) is tuple else background - # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) font_color = _select_font_color(background) @@ -1990,6 +2228,8 @@ def _standardize_dimensions(x, y, z, n_obs, projection): assert len(x) == len(y) and len(x) == len(z), "bug: x, y, z does not have the same shape." + return x, y, z + def _standardize_dimension(dim, n_obs): if type(dim) in [anndata._core.views.ArrayView, np.ndarray] and len(dim) == n_obs: @@ -2127,11 +2367,13 @@ def _map_cur_axis_to_title( else: x_points_df_data, x_points_column = _map_cur_axis_to_title(axis_x, _adata, cur_b, cur_l_smoothed) y_points_df_data, y_points_column = _map_cur_axis_to_title(axis_y, _adata, cur_b, cur_l_smoothed) + x_points_df_data = x_points_df_data.A.flatten() if issparse(x_points_df_data) else x_points_df_data + y_points_df_data = y_points_df_data.A.flatten() if issparse(y_points_df_data) else y_points_df_data points = pd.DataFrame({ axis_x: x_points_df_data, axis_y: y_points_df_data, - }) - points.columns = [x_points_column, y_points_column] + }, index=_adata.obs_names) + points.columns = [axis_x, axis_y] if len(gene_title) != 0: cur_title = " VS ".join(gene_title) From 47dca5308014a3acad1c0f3a83da1bb415f875c8 Mon Sep 17 00:00:00 2001 From: sichao Date: Wed, 17 Jan 2024 17:02:34 -0500 Subject: [PATCH 4/9] update new scatters with stack feature --- dynamo/plot/scatters.py | 136 ++++++++++++++++++++++------------------ 1 file changed, 74 insertions(+), 62 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 68acd27d8..a508e9710 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1417,6 +1417,32 @@ def scatters2( % calpha ) + if stack_colors and stack_colors_cmaps is None: + main_info("using default stack colors") + stack_colors_cmaps = [ + "Greys", + "Purples", + "Blues", + "Greens", + "Oranges", + "Reds", + "YlOrBr", + "YlOrRd", + "OrRd", + "PuRd", + "RdPu", + "BuPu", + "GnBu", + "PuBu", + "YlGnBu", + "PuBuGn", + "BuGn", + "YlGn", + ] + stack_legend_handles = [] + if stack_colors: + color_key = None + if background is None: background = rcParams.get("figure.facecolor") background = to_hex(background) if type(background) is tuple else background @@ -1491,6 +1517,9 @@ def scatters2( axes_list, color_list = [], [] color_out = None + if stack_colors: + ax = None + for cur_b in basis: for cur_l in layer: main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) @@ -1505,7 +1534,7 @@ def scatters2( ax = plt.subplot(gs[ax_index], projection=projection) ax_index += 1 - ax, color_out, font_color = scatters_single_input( + ax, color_out, font_color, stack_legend_handles = scatters_single_input( adata=adata, basis=cur_b, x=cur_x, @@ -1521,8 +1550,6 @@ def scatters2( color_key=color_key, color_key_cmap=color_key_cmap, background=background, - ncols=ncols, - pointsize=pointsize, figsize=figsize, show_legend=show_legend, use_smoothed=use_smoothed, @@ -1530,23 +1557,16 @@ def scatters2( show_arrowed_spines=show_arrowed_spines, ax=ax, sort=sort, - save_show_or_return=save_show_or_return, - save_kwargs=save_kwargs, - return_all=return_all, add_gamma_fit=add_gamma_fit, frontier=frontier, contour=contour, ccmap=ccmap, - alpha=alpha, calpha=calpha, sym_c=sym_c, smooth=smooth, - dpi=dpi, inset_dict=inset_dict, - marker=marker, group=group, add_group_gamma_fit=add_group_gamma_fit, - affine_transform_degree=affine_transform_degree, affine_transform_A=affine_transform_A, affine_transform_b=affine_transform_b, stack_colors=stack_colors, @@ -1554,57 +1574,25 @@ def scatters2( stack_colors_title=stack_colors_title, stack_colors_legend_size=stack_colors_legend_size, stack_colors_cmaps=stack_colors_cmaps, + stack_legend_handles=stack_legend_handles, despline=despline, deaxis=deaxis, despline_sides=despline_sides, projection=projection, scatter_kwargs=scatter_kwargs, - ** kwargs, + ax_index=ax_index, ) axes_list.append(ax) color_list.append(color_out) main_debug("show, return or save...") - if save_show_or_return in ["save", "both", "all"]: - s_kwargs = { - "path": None, - "prefix": "scatters", - "dpi": None, - "ext": "pdf", - "transparent": True, - "close": True, - "verbose": True, - } - - # prevent the plot from being closed if the plot need to be shown or returned. - if save_show_or_return in ["both", "all"]: - s_kwargs["close"] = False - - s_kwargs = update_dict(s_kwargs, save_kwargs) - - save_fig(**s_kwargs) - if background is not None: - reset_rcParams() - if save_show_or_return in ["show", "both", "all"]: - if show_legend: - plt.subplots_adjust(right=0.85) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - plt.tight_layout() - - plt.show() - if background is not None: - reset_rcParams() - if save_show_or_return in ["return", "all"]: - if background is not None: - reset_rcParams() - - if return_all: - return (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) - else: - return axes_list if total_panels > 1 else ax + if return_all: + return_value = (axes_list, color_list, font_color) if total_panels > 1 else (ax, color_out, font_color) + else: + return_value = axes_list if total_panels > 1 else ax + return save_show_ret("scatters", save_show_or_return, save_kwargs, return_value, adjust=show_legend, + background=background) def scatters_single_input( @@ -1635,8 +1623,6 @@ def scatters_single_input( color_key: Union[Dict[str, str], List[str], None] = None, color_key_cmap: Optional[str] = None, background: Optional[str] = None, - ncols: int = 4, - pointsize: Optional[float] = None, figsize: Tuple[float, float] = (6, 4), show_legend: str = "on data", use_smoothed: bool = True, @@ -1644,23 +1630,16 @@ def scatters_single_input( show_arrowed_spines: bool = False, ax: Optional[Axes] = None, sort: Literal["raw", "abs", "neg"] = "raw", - save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", - save_kwargs: Dict[str, Any] = {}, - return_all: bool = False, add_gamma_fit: bool = False, frontier: bool = False, contour: bool = False, ccmap: Optional[str] = None, - alpha: float = 0.1, calpha: float = 0.4, sym_c: bool = False, smooth: bool = False, - dpi: int = 100, inset_dict: Dict[str, Any] = {}, - marker: Optional[str] = None, group: Optional[str] = None, add_group_gamma_fit: bool = False, - affine_transform_degree: Optional[int] = None, affine_transform_A: Optional[float] = None, affine_transform_b: Optional[float] = None, stack_colors: bool = False, @@ -1668,12 +1647,13 @@ def scatters_single_input( stack_colors_title: str = "stacked colors", stack_colors_legend_size: float = 2, stack_colors_cmaps: Optional[List[str]] = None, + stack_legend_handles: Optional[List[Line2D]] = None, despline: bool = True, deaxis: bool = True, despline_sides: Optional[List[str]] = None, projection: str = "2d", scatter_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, + ax_index: Optional[int] = None, ) -> Union[ Axes, List[Axes], @@ -1927,8 +1907,28 @@ def scatters_single_input( color_key_cmap=color_key_cmap, background=background, title=cur_title, + stack_colors=stack_colors, + stack_colors_threshold=stack_colors_threshold, ) + if stack_colors: + main_debug("stack colors: changing cmap") + cur_title = stack_colors_title + cmap = stack_colors_cmaps[(ax_index - 1) % len(stack_colors_cmaps)] + max_color = matplotlib.cm.get_cmap(cmap)(float("inf")) + # TODO: consider remove the legend because it is not helpful + legend_circle = Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=max_color, + label=color, + markersize=stack_colors_legend_size, + ) + stack_legend_handles.append(legend_circle) + ax.legend(handles=stack_legend_handles, loc="upper right", prop={"size": stack_colors_legend_size}) + if aggregate is not None: groups, uniq_grp = ( adata.obs[aggregate], @@ -2114,10 +2114,10 @@ def scatters_single_input( "_adata does not seem to have %s column. Velocity estimation is required " "before running this function." % group_k_name ) - return ax, color_out, font_color + return ax, color_out, font_color, stack_legend_handles -def _get_color_parameters(adata, color, layer, labels, values, theme, cmap, color_key_cmap, background, title): +def _get_color_parameters(adata, color, layer, labels, values, theme, cmap, color_key_cmap, background, title, stack_colors, stack_colors_threshold): font_color = _select_font_color(background) @@ -2165,6 +2165,18 @@ def _get_theme_for_values(): indent_level=2, ) + if stack_colors and np.issubdtype(_color.dtype, np.number): + main_debug("Subsetting adata by stack_colors") + _stack_background_adata_indices = np.ones(len(adata), dtype=bool) + _stack_background_adata_indices = np.logical_and( + _stack_background_adata_indices, (_color < stack_colors_threshold) + ) + if values is not None: + values = values[_color > stack_colors_threshold] + _color = _color[_color > stack_colors_threshold] + if len(_color) == 0: + main_info("skipping filtering %s because no point of %s is above threshold" % (color, color)) + is_not_continuous = not isinstance(_color[0], Number) or _color.dtype.name == "category" if is_not_continuous: From b6a21ee3021b6317df283527823413c45e8a65e3 Mon Sep 17 00:00:00 2001 From: sichao Date: Thu, 18 Jan 2024 13:02:53 -0500 Subject: [PATCH 5/9] rename funcs and update docstr --- dynamo/plot/scatters.py | 227 +++++++++++++++++++++++++++++++--------- 1 file changed, 177 insertions(+), 50 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index a508e9710..b25ebf4b6 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -49,7 +49,7 @@ @docstrings.get_sectionsf("scatters") -def scatters( +def scatters_legacy( adata: AnnData, basis: str = "umap", x: int = 0, @@ -1333,7 +1333,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: ) -def scatters2( +def scatters( adata: AnnData, basis: str = "umap", x: int = 0, @@ -1406,6 +1406,153 @@ def scatters2( Tuple[List[Axes], List[str], Literal["white", "black"]], None, ]: + """Plot an embedding as points. + + Support both 2D and 3D embeddings. While there are many optional parameters to further control and tailor the + plotting, you need only pass in the trained/fit umap model to get results. This plot utility will attempt to do the + hard work of avoiding overplotting issues, and make it easy to automatically color points by a categorical labelling + or numeric values. This method is intended to be used within a Jupyter notebook with `%matplotlib inline`. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + highlights: the color group that will be highlighted. If highlights is a list of lists, each list is relate to + each color element. Defaults to None. + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Available themes are: {'blue', 'red', 'green', 'inferno', 'fire', 'viridis', 'darkblue', + 'darkred', 'darkgreen'}. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + ncols: the number of columns for the figure. Defaults to 4. + pointsize: the scale of the point size. Actual point cell size is calculated as + `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. + figsize: the width and height of a figure. Defaults to (6, 4). + show_legend: whether to display a legend of the labels. Defaults to "on data". + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + aggregate: the column in adata.obs that will be used to aggregate data points. Defaults to None. + show_arrowed_spines: whether to show a pair of arrowed spines representing the basis of the scatter is currently + using. Defaults to False. + ax: the matplotlib axes object where new plots will be added to. Only applicable to drawing a single component. + Defaults to None. + sort: the method to reorder data so that high values points will be on top of background points. Can be one of + {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults + to "raw". + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will passed to the save_show_ret function. By default it is an empty dictionary and + the save_show_ret function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + return_all: whether to return all the scatter related variables. Defaults to False. + add_gamma_fit: whether to add the line of the gamma fitting. This will automatically turn on if `basis` points + to gene names and those genes have went through gamma fitting. Defaults to False. + frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to + show area of high density and multiple scatter plots can be used to delineate a frontier. See matplotlib + tips & tricks cheatsheet (https://github.com/matplotlib/cheatsheets). Originally inspired by figures from + scEU-seq paper: https://science.sciencemag.org/content/367/6482/1151. If `contour` is set to be True, + `frontier` will be ignored as `contour` also add an outlier for data points. Defaults to False. + contour: whether to add an countor on top of scatter plots. We use tricontourf to plot contour for non-gridded + data. The shapely package was used to create a polygon of the concave hull of the scatters. With the polygon + we then check if the mean of the triangulated points are within the polygon and use this as our condition to + form the mask to create the contour. We also add the polygon shape as a frontier of the data point (similar + to when setting `frontier = True`). When the color of the data points is continuous, we will use the same + cmap as for the scatter points by default, when color is categorical, no contour will be drawn but just the + polygon. cmap can be set with `ccmap` argument. See below. This has recently changed to use seaborn's + kdeplot. Defaults to False. + ccmap: the name of a matplotlib colormap to use for coloring or shading points the contour. See above. + Defaults to None. + alpha: the point's alpha (transparency) value. Defaults to 0.1. + calpha: contour alpha value passed into sns.kdeplot. The value should be inbetween [0, 1]. Defaults to 0.4. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will based on `smooth` steps of + diffusion. + dpi: the resolution of the figure in dots-per-inch. Dots per inches (dpi) determines how many pixels the figure + comprises. dpi is different from ppi or points per inches. Note that most elements like lines, markers, + texts have a size given in points so you can convert the points to inches. Matplotlib figures use Points per + inch (ppi) of 72. A line with thickness 1 point will be 1./72. inch wide. A text with fontsize 12 points + will be 12./72. inch heigh. Of course if you change the figure size in inches, points will not change, so a + larger figure in inches still has the same size of the elements.Changing the figure size is thus like taking + a piece of paper of a different size. Doing so, would of course not change the width of the line drawn with + the same pen. On the other hand, changing the dpi scales those elements. At 72 dpi, a line of 1 point size + is one pixel strong. At 144 dpi, this line is 2 pixels strong. A larger dpi will therefore act like a + magnifying glass. All elements are scaled by the magnifying power of the lens. see more details at answer 2 + by @ImportanceOfBeingErnest: + https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size. Defaults to 100. + inset_dict: a dictionary of parameters in inset_ax. Example, something like {"width": "5%", "height": "50%", "loc": + 'lower left', "bbox_to_anchor": (0.85, 0.90, 0.145, 0.145), "bbox_transform": ax.transAxes, "borderpad": 0} + See more details at https://matplotlib.org/api/_as_gen/mpl_toolkits.axes_grid1.inset_locator.inset_axes.html + or https://stackoverflow.com/questions/39803385/what-does-a-4-element-tuple-argument-for-bbox-to-anchor-mean-in-matplotlib. + Defaults to {}. + marker: the marker style. marker can be either an instance of the class or the text shorthand for a particular + marker. See matplotlib.markers for more information about marker styles. Defaults to None. + group: the key in `adata.obs` corresponding to the cell group data. Defaults to None. + add_group_gamma_fit: whether to plot the cell group's gamma fit results. Defaults to False. + affine_transform_degree: transform coordinates of points according to some degree. Defaults to None. + affine_transform_A: coefficients in affine transformation Ax + b. 2D for now. Defaults to None. + affine_transform_b: bias in affine transformation Ax + b. Defaults to None. + stack_colors: whether to stack all color on the same ax passed above. Currently only support 18 sequential + matplotlib default cmaps assigning to different color groups. (#colors should be smaller than 18, reuse if + #colors > 18. TODO generate cmaps according to #colors). Defaults to False. + stack_colors_threshold: a threshold for filtering out points values < threshold when drawing each color. E.g. if + you do not want points with values < 1 showing up on axis, set threshold to be 1. Defaults to 0.001. + stack_colors_title: the title for the stack_color plot. Defaults to "stacked colors". + stack_colors_legend_size: the legend size in stack color plot. Defaults to 2. + stack_colors_cmaps: a list of cmaps that will be used to map values to color when stacking colors on the same + subplot. The order corresponds to the order of color. Defaults to None. + despline: whether to remove splines of the figure. Defaults to True. + deaxis: whether to remove axis ticks of the figure. Defaults to True. + despline_sides: which side of splines should be removed. Can be any combination of `["bottom", "right", "top", "left"]`. Defaults to None. + projection: the projection property of the matplotlib.Axes. Defaults to "2d". + **kwargs: any other kwargs that would be passed to `pyplot.scatters`. + + Raises: + ValueError: invalid adata object: lacking of required layers. + ValueError: `basis` not found in `adata.obsm`. + ValueError: invalid `x` or `y`. + ValueError: `labels` and `values` conflicted. + ValueError: invalid velocity estimation in `adata`. + ValueError: invalid velocity estimation in `adata`. + + Returns: + None would be returned by default. If `save_show_or_return` is set to be 'return' or 'all', the matplotlib axes + object of the generated plots would be returned. If `return_all` is set to be true, the list of colors used and + the font color would also be returned. + """ import matplotlib.pyplot as plt if projection == "2d": @@ -1661,11 +1808,7 @@ def scatters_single_input( Tuple[List[Axes], List[str], Literal["white", "black"]], None, ]: - """Plot an embedding as points. Currently this only works for 2D embeddings. While there are many optional - parameters to further control and tailor the plotting, you need only pass in the trained/fit umap model to get - results. This plot utility will attempt to do the hard work of avoiding overplotting issues, and make it easy to - automatically color points by a categorical labelling or numeric values. This method is intended to be used within a - Jupyter notebook with `%matplotlib inline`. + """Generate a single scatter plot from given embedding. Args: adata: an AnnData object. @@ -1678,7 +1821,7 @@ def scatters_single_input( z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". layer: the layer of data to use for the scatter plot. Defaults to "X". - highlights: the color group that will be highlighted. If highligts is a list of lists, each list is relate to + highlights: the color group that will be highlighted. If highlights is a list of lists, each list is relate to each color element. Defaults to None. labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for coloring the points in the plot according to their label. Note that this option is mutually exclusive to the @@ -1706,9 +1849,6 @@ def scatters_single_input( work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of the things that themes handle for you. Note that if theme is passed then this value will be overridden by the corresponding option of the theme. Defaults to None. - ncols: the number of columns for the figure. Defaults to 4. - pointsize: the scale of the point size. Actual point cell size is calculated as - `500.0 / np.sqrt(adata.shape[0]) * pointsize`. Defaults to None. figsize: the width and height of a figure. Defaults to (6, 4). show_legend: whether to display a legend of the labels. Defaults to "on data". use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to @@ -1721,14 +1861,6 @@ def scatters_single_input( sort: the method to reorder data so that high values points will be on top of background points. Can be one of {'raw', 'abs', 'neg'}, i.e. sorted by raw data, sort by absolute values or sort by negative values. Defaults to "raw". - save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at - the same time. If "all", the figure will be saved, displayed and the associated axis and other object will - be return. Defaults to "show". - save_kwargs: A dictionary that will passed to the save_fig function. By default it is an empty dictionary and - the save_fig function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', - "transparent": True, "close": True, "verbose": True} as its parameters. Otherwise you can provide a - dictionary that properly modify those keys according to your needs. Defaults to {}. - return_all: whether to return all the scatter related variables. Defaults to False. add_gamma_fit: whether to add the line of the gamma fitting. This will automatically turn on if `basis` points to gene names and those genes have went through gamma fitting. Defaults to False. frontier: whether to add the frontier. Scatter plots can be enhanced by using transparency (alpha) in order to @@ -1746,7 +1878,6 @@ def scatters_single_input( kdeplot. Defaults to False. ccmap: the name of a matplotlib colormap to use for coloring or shading points the contour. See above. Defaults to None. - alpha: the point's alpha (transparency) value. Defaults to 0.1. calpha: contour alpha value passed into sns.kdeplot. The value should be inbetween [0, 1]. Defaults to 0.4. sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative @@ -1755,28 +1886,13 @@ def scatters_single_input( smoothing will be applied. If `True`, smoothing based on one step diffusion of connectivity matrix (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will based on `smooth` steps of diffusion. - dpi: the resolution of the figure in dots-per-inch. Dots per inches (dpi) determines how many pixels the figure - comprises. dpi is different from ppi or points per inches. Note that most elements like lines, markers, - texts have a size given in points so you can convert the points to inches. Matplotlib figures use Points per - inch (ppi) of 72. A line with thickness 1 point will be 1./72. inch wide. A text with fontsize 12 points - will be 12./72. inch heigh. Of course if you change the figure size in inches, points will not change, so a - larger figure in inches still has the same size of the elements.Changing the figure size is thus like taking - a piece of paper of a different size. Doing so, would of course not change the width of the line drawn with - the same pen. On the other hand, changing the dpi scales those elements. At 72 dpi, a line of 1 point size - is one pixel strong. At 144 dpi, this line is 2 pixels strong. A larger dpi will therefore act like a - magnifying glass. All elements are scaled by the magnifying power of the lens. see more details at answer 2 - by @ImportanceOfBeingErnest: - https://stackoverflow.com/questions/47633546/relationship-between-dpi-and-figure-size. Defaults to 100. inset_dict: a dictionary of parameters in inset_ax. Example, something like {"width": "5%", "height": "50%", "loc": 'lower left', "bbox_to_anchor": (0.85, 0.90, 0.145, 0.145), "bbox_transform": ax.transAxes, "borderpad": 0} See more details at https://matplotlib.org/api/_as_gen/mpl_toolkits.axes_grid1.inset_locator.inset_axes.html or https://stackoverflow.com/questions/39803385/what-does-a-4-element-tuple-argument-for-bbox-to-anchor-mean-in-matplotlib. Defaults to {}. - marker: the marker style. marker can be either an instance of the class or the text shorthand for a particular - marker. See matplotlib.markers for more information about marker styles. Defaults to None. group: the key in `adata.obs` corresponding to the cell group data. Defaults to None. add_group_gamma_fit: whether to plot the cell group's gamma fit results. Defaults to False. - affine_transform_degree: transform coordinates of points according to some degree. Defaults to None. affine_transform_A: coefficients in affine transformation Ax + b. 2D for now. Defaults to None. affine_transform_b: bias in affine transformation Ax + b. Defaults to None. stack_colors: whether to stack all color on the same ax passed above. Currently only support 18 sequential @@ -1792,20 +1908,9 @@ def scatters_single_input( deaxis: whether to remove axis ticks of the figure. Defaults to True. despline_sides: which side of splines should be removed. Can be any combination of `["bottom", "right", "top", "left"]`. Defaults to None. projection: the projection property of the matplotlib.Axes. Defaults to "2d". - **kwargs: any other kwargs that would be passed to `pyplot.scatters`. - - Raises: - ValueError: invalid adata object: lacking of required layers. - ValueError: `basis` not found in `adata.obsm`. - ValueError: invalid `x` or `y`. - ValueError: `labels` and `values` conflicted. - ValueError: invalid velocity estimation in `adata`. - ValueError: invalid velocity estimation in `adata`. Returns: - None would be returned by default. If `save_show_or_return` is set to be 'return' or 'all', the matplotlib axes - object of the generated plots would be returned. If `return_all` is set to be true, the list of colors used and - the font color would also be returned. + The matplotlib axes object of the generated plot. """ import matplotlib.pyplot as plt @@ -2117,7 +2222,21 @@ def scatters_single_input( return ax, color_out, font_color, stack_legend_handles -def _get_color_parameters(adata, color, layer, labels, values, theme, cmap, color_key_cmap, background, title, stack_colors, stack_colors_threshold): +def _get_color_parameters( + adata: AnnData, + color: str, + layer: str, + labels: Optional[list] = None, + values: Optional[list] = None, + theme: Optional[str] = None, + cmap: Optional[str] = None, + color_key_cmap: Optional[str] = None, + background: Optional[str] = None, + title: str = "", + stack_colors: bool = False, + stack_colors_threshold: float = 0.001, +) -> Tuple: + """A helper function to get color parameters for scatters function.""" font_color = _select_font_color(background) @@ -2194,7 +2313,14 @@ def _get_theme_for_values(): return _color, labels, values, _theme_, cmap, color_key_cmap, background, font_color, is_not_continuous, title -def _standardize_dimensions(x, y, z, n_obs, projection): +def _standardize_dimensions( + x: Union[int, str, np.ndarray, anndata._core.views.ArrayView], + y: Union[int, str, np.ndarray, anndata._core.views.ArrayView], + z: Union[int, str, np.ndarray, anndata._core.views.ArrayView], + n_obs: int, + projection: str, +) -> Tuple: + """A helper function to standardize the dimensions of x, y, z.""" if projection == "3d": x, y, z = ( _standardize_dimension(x, n_obs), @@ -2210,7 +2336,8 @@ def _standardize_dimensions(x, y, z, n_obs, projection): return x, y, z -def _standardize_dimension(dim, n_obs): +def _standardize_dimension(dim: Union[int, str, np.ndarray, anndata._core.views.ArrayView], n_obs: int) -> List: + """A helper function to standardize the dimension of a given axis.""" if type(dim) in [anndata._core.views.ArrayView, np.ndarray] and len(dim) == n_obs: return [dim] elif hasattr(dim, "__len__"): From ea4882ab119c32f9d6d8a02e494ff5affec9a7fd Mon Sep 17 00:00:00 2001 From: sichao Date: Mon, 29 Jan 2024 12:02:53 -0500 Subject: [PATCH 6/9] debug shape in 3d --- dynamo/plot/scatters.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index b25ebf4b6..9d2788000 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1635,7 +1635,7 @@ def scatters( if type(basis) is str: basis = [basis] - n_c, n_l, n_b, n_x, n_y = ( + n_c, n_l, n_b, n_x, n_y, n_z = ( 1 if color is None else len(color), 1 if layer is None else len(layer), 1 if basis is None else len(basis), @@ -1643,11 +1643,12 @@ def scatters( # check whether it is an array 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), # check whether it is an array + 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), ) total_panels, ncols = ( - n_c * n_l * n_b * n_x * n_y, - min(max([n_c, n_l, n_b, n_x, n_y]), ncols), + n_c * n_l * n_b * n_x, + min(max([n_c, n_l, n_b, n_x, n_y, n_z]), ncols), ) nrow, ncol = int(np.ceil(total_panels / ncols)), ncols From 2b1ffa26fd16f8ee999167eb6388d341a7e9a028 Mon Sep 17 00:00:00 2001 From: sichao Date: Tue, 30 Jan 2024 10:19:19 -0500 Subject: [PATCH 7/9] create _get_basis_key helper func --- dynamo/plot/scatters.py | 139 ++++++++++++++++++++++++---------------- 1 file changed, 83 insertions(+), 56 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index 9d2788000..c53cf6094 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1590,6 +1590,35 @@ def scatters( if stack_colors: color_key = None + if all([is_gene_name(adata, i) for i in basis]): + if x[0] not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y[0] not in [ + "M_u", + "X_unspliced", + "M_n", + "X_new", + "unspliced", + "new", + ]: + if "M_t" in adata.layers.keys() and "M_n" in adata.layers.keys(): + x, y = ["M_t"], ["M_n"] + elif "X_total" in adata.layers.keys() and "X_new" in adata.layers.keys(): + x, y = ["X_total"], ["X_new"] + elif "M_s" in adata.layers.keys() and "M_u" in adata.layers.keys(): + x, y = ["M_s"], ["M_u"] + elif "X_spliced" in adata.layers.keys() and "X_unspliced" in adata.layers.keys(): + x, y = ["X_spliced"], ["X_unspliced"] + elif "spliced" in adata.layers.keys() and "unspliced" in adata.layers.keys(): + x, y = ["spliced"], ["unspliced"] + elif "total" in adata.layers.keys() and "new" in adata.layers.keys(): + x, y = ["total"], ["new"] + else: + raise ValueError( + "your adata object is corrupted. Please make sure it has at least one of the following " + "pair of layers:" + "'M_s', 'X_spliced', 'M_t', 'X_total', 'spliced', 'total' and " + "'M_u', 'X_unspliced', 'M_n', 'X_new', 'unspliced', 'new'. " + ) + if background is None: background = rcParams.get("figure.facecolor") background = to_hex(background) if type(background) is tuple else background @@ -1670,8 +1699,15 @@ def scatters( for cur_b in basis: for cur_l in layer: - main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) - main_debug("colors: %s" % (str(color))) + + basis_key, cur_l_smoothed, cmap, sym_c = _get_basis_key( + adata=adata, + basis=cur_b, + layer=cur_l, + use_smoothed=use_smoothed, + cmap=cmap, + sym_c=sym_c, + ) for cur_c in color: x, y, z = _standardize_dimensions(x, y, z, adata.n_obs, projection) @@ -1729,6 +1765,8 @@ def scatters( projection=projection, scatter_kwargs=scatter_kwargs, ax_index=ax_index, + cur_l_smoothed=cur_l_smoothed, + basis_key=basis_key, ) axes_list.append(ax) @@ -1802,6 +1840,8 @@ def scatters_single_input( projection: str = "2d", scatter_kwargs: Optional[Dict[str, Any]] = None, ax_index: Optional[int] = None, + cur_l_smoothed: Optional[str] = None, + basis_key: Optional[str] = None, ) -> Union[ Axes, List[Axes], @@ -1915,66 +1955,14 @@ def scatters_single_input( """ import matplotlib.pyplot as plt - if is_gene_name(adata, basis): - if x not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y not in [ - "M_u", - "X_unspliced", - "M_n", - "X_new", - "unspliced", - "new", - ]: - if "M_t" in adata.layers.keys() and "M_n" in adata.layers.keys(): - x, y = "M_t", "M_n" - elif "X_total" in adata.layers.keys() and "X_new" in adata.layers.keys(): - x, y = "X_total", "X_new" - elif "M_s" in adata.layers.keys() and "M_u" in adata.layers.keys(): - x, y = "M_s", "M_u" - elif "X_spliced" in adata.layers.keys() and "X_unspliced" in adata.layers.keys(): - x, y = "X_spliced", "X_unspliced" - elif "spliced" in adata.layers.keys() and "unspliced" in adata.layers.keys(): - x, y = "spliced", "unspliced" - elif "total" in adata.layers.keys() and "new" in adata.layers.keys(): - x, y = "total", "new" - else: - raise ValueError( - "your adata object is corrupted. Please make sure it has at least one of the following " - "pair of layers:" - "'M_s', 'X_spliced', 'M_t', 'X_total', 'spliced', 'total' and " - "'M_u', 'X_unspliced', 'M_n', 'X_new', 'unspliced', 'new'. " - ) - - if use_smoothed: - mapper = get_mapper() - # if #total_panel is 1, `_matplotlib_points` will create a figure. No need to create a figure here and generate a blank figure. if ax is None: figure, ax = plt.subplots() color_out = None - if layer in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: - cur_l_smoothed = layer - cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future - else: - if use_smoothed: - cur_l_smoothed = layer if layer.startswith("M_") | layer.startswith("velocity") else mapper[layer] - if layer.startswith("velocity"): - cmap, sym_c = "bwr", True - - if layer + "_" + basis in adata.obsm.keys(): - prefix = layer + "_" - elif ("X_" + basis) in adata.obsm.keys(): - prefix = "X_" - elif basis in adata.obsm.keys(): - # special case for spatial for compatibility with other packages - prefix = "" - elif is_gene_name(adata, basis): - prefix = "" - else: - raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) - - basis_key = prefix + basis + basis_key = basis_key if basis_key is not None else basis + cur_l_smoothed = cur_l_smoothed if cur_l_smoothed is not None else layer main_info("plotting with basis key=%s" % basis_key, indent_level=2) main_debug("coloring scatter of cur_c: %s" % str(color)) @@ -2223,6 +2211,45 @@ def scatters_single_input( return ax, color_out, font_color, stack_legend_handles +def _get_basis_key( + adata: AnnData, + basis: str, + layer: str, + use_smoothed: bool, + cmap: Optional[str] = None, + sym_c: Optional[bool] = None, +) -> str: + """A helper function to get basis key for scatters function.""" + if use_smoothed: + mapper = get_mapper() + + cur_l_smoothed = layer + + if layer in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: + cur_l_smoothed = layer + cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future + else: + if use_smoothed: + cur_l_smoothed = layer if layer.startswith("M_") | layer.startswith("velocity") else mapper[layer] + if layer.startswith("velocity"): + cmap, sym_c = "bwr", True + + if layer + "_" + basis in adata.obsm.keys(): + prefix = layer + "_" + elif ("X_" + basis) in adata.obsm.keys(): + prefix = "X_" + elif basis in adata.obsm.keys(): + # special case for spatial for compatibility with other packages + prefix = "" + elif is_gene_name(adata, basis): + prefix = "" + else: + raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) + + basis_key = prefix + basis + + return basis_key, cur_l_smoothed, cmap, sym_c + def _get_color_parameters( adata: AnnData, color: str, From 70740e39c25744cba55411ba53dff1cd459c41ce Mon Sep 17 00:00:00 2001 From: sichao Date: Tue, 30 Jan 2024 13:38:16 -0500 Subject: [PATCH 8/9] create new scatters_interactive func --- dynamo/plot/scatters.py | 334 ++++++++++++++++++++++++++++++- dynamo/vectorfield/topography.py | 8 +- 2 files changed, 337 insertions(+), 5 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index c53cf6094..e0a5e1dda 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -986,7 +986,7 @@ def _map_cur_axis(cur: str) -> Tuple[np.ndarray, str]: return points, cur_title -def scatters_interactive( +def scatters_interactive_legacy( adata: AnnData, basis: str = "umap", x: Union[int, str] = 0, @@ -1333,6 +1333,338 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str) -> None: ) +def scatters_interactive( + adata: AnnData, + basis: str = "umap", + x: Union[int, str] = 0, + y: Union[int, str] = 1, + z: Union[int, str] = 2, + color: str = "ntr", + layer: str = "X", + plot_method: str = "pv", + labels: Optional[list] = None, + values: Optional[list] = None, + cmap: Optional[str] = None, + theme: Optional[str] = None, + background: Optional[str] = None, + color_key: Union[Dict[str, str], List[str], None] = None, + color_key_cmap: Optional[str] = None, + use_smoothed: bool = True, + sym_c: bool = False, + smooth: bool = False, + save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", + save_kwargs: Dict[str, Any] = {}, + **kwargs, +): + """Plot an embedding as points with Pyvista. Currently only 3D input is supported. For 2D data, `scatters` is a + better alternative. + + The function will use the colors from matplotlib to keep consistence with other plotting functions. + + Args: + adata: an AnnData object. + basis: the reduced dimension stored in adata.obsm. The specific basis key will be constructed in the following + priority if exits: 1) specific layer input + basis 2) X_ + basis 3) basis. E.g. if basis is PCA, `scatters` + is going to look for 1) if specific layer is spliced, `spliced_pca` 2) `X_pca` (dynamo convention) 3) `pca`. + Defaults to "umap". + x: the column index of the low dimensional embedding for the x-axis. Defaults to 0. + y: the column index of the low dimensional embedding for the y-axis. Defaults to 1. + z: the column index of the low dimensional embedding for the z-axis. Defaults to 2. + color: any column names or gene expression, etc. that will be used for coloring cells. Defaults to "ntr". + layer: the layer of data to use for the scatter plot. Defaults to "X". + labels: an array of labels (assumed integer or categorical), one for each data sample. This will be used for + coloring the points in the plot according to their label. Note that this option is mutually exclusive to the + `values` option. Defaults to None. + values: an array of values (assumed float or continuous), one for each sample. This will be used for coloring + the points in the plot according to a colorscale associated to the total range of values. Note that this + option is mutually exclusive to the `labels` option. Defaults to None. + theme: A color theme to use for plotting. A small set of predefined themes are provided which have relatively + good aesthetics. Defaults to None. + cmap: The name of a matplotlib colormap to use for coloring or shading points. If no labels or values are passed + this will be used for shading points according to density (largely only of relevance for very large + datasets). If values are passed this will be used for shading according the value. Note that if theme is + passed then this value will be overridden by the corresponding option of the theme. Defaults to None. + background: the color of the background. Usually this will be either 'white' or 'black', but any color name will + work. Ideally one wants to match this appropriately to the colors being used for points etc. This is one of + the things that themes handle for you. Note that if theme is passed then this value will be overridden by + the corresponding option of the theme. Defaults to None. + color_key: the method to assign colors to categoricals. This can either be an explicit dict mapping labels to + colors (as strings of form '#RRGGBB'), or an array like object providing one color for each distinct + category being provided in `labels`. Either way this mapping will be used to color points according to the + label. Note that if theme is passed then this value will be overridden by the corresponding option of the + theme. Defaults to None. + color_key_cmap: the name of a matplotlib colormap to use for categorical coloring. If an explicit `color_key` is + not given a color mapping for categories can be generated from the label list and selecting a matching list + of colors from the given colormap. Note that if theme is passed then this value will be overridden by the + corresponding option of the theme. Defaults to None. + use_smoothed: whether to use smoothed values (i.e. M_s / M_u instead of spliced / unspliced, etc.). Defaults to + True. + sym_c: whether do you want to make the limits of continuous color to be symmetric, normally this should be used + for plotting velocity, jacobian, curl, divergence or other types of data with both positive or negative + values. Defaults to False. + smooth: whether do you want to further smooth data and how much smoothing do you want. If it is `False`, no + smoothing will be applied. If `True`, smoothing based on one-step diffusion of connectivity matrix + (`.uns['moment_cnn']`) will be applied. If a number larger than 1, smoothing will be based on `smooth` steps + of diffusion. + save_show_or_return: whether to save, show or return the figure. If "both", it will save and plot the figure at + the same time. If "all", the figure will be saved, displayed and the associated axis and other object will + be return. Defaults to "show". + save_kwargs: A dictionary that will be passed to the saving function. By default, it is an empty dictionary + and the saving function will use the {"path": None, "prefix": 'scatter', "dpi": None, "ext": 'pdf', + "title": PyVista Export, "raster": True, "painter": True} as its parameters. Otherwise, you can provide a + dictionary that properly modify those keys according to your needs. Defaults to {}. + **kwargs: any other kwargs that would be passed to `Plotter.add_points()`. + + Returns: + If `save_show_or_return` is `save`, `show` or `both`, the function will return nothing but show or save the + figure. If `save_show_or_return` is `return`, the function will return the axis object(s) that contains the + figure. + """ + + if plot_method == "pv": + try: + import pyvista as pv + except ImportError: + raise ImportError("Please install pyvista first.") + elif plot_method == "plotly": + try: + import plotly.express as px + import plotly.graph_objects as go + from plotly.subplots import make_subplots + except ImportError: + raise ImportError("Please install plotly first.") + else: + raise NotImplementedError("Current plot method not supported.") + + if background is None: + background = rcParams.get("figure.facecolor") + background = to_hex(background) if type(background) is tuple else background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + + if type(x) in [int, str]: + x = [x] + if type(y) in [int, str]: + y = [y] + if type(z) in [int, str]: + z = [z] + + # make x, y, z lists of list, where each list corresponds to one coordinate set + if ( + type(x) in [anndata._core.views.ArrayView, np.ndarray] + and type(y) in [anndata._core.views.ArrayView, np.ndarray] + and type(z) in [anndata._core.views.ArrayView, np.ndarray] + and len(x) == adata.n_obs + and len(y) == adata.n_obs + and len(z) == adata.n_obs + ): + x, y, z = [x], [y], [z] + + elif hasattr(x, "__len__") and hasattr(y, "__len__") and hasattr(z, "__len__"): + x, y, z = list(x), list(y), list(z) + + assert len(x) == len(y) and len(x) == len(z), "bug: x, y, z does not have the same shape." + + if use_smoothed: + mapper = get_mapper() + + # check color, layer, basis -> convert to list + if type(color) is str: + color = [color] + if type(layer) is str: + layer = [layer] + if type(basis) is str: + basis = [basis] + + n_c, n_l, n_b, n_x, n_y, n_z = ( + 1 if color is None else len(color), + 1 if layer is None else len(layer), + 1 if basis is None else len(basis), + 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), + 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), + 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), + ) + + total_panels, ncols = ( + n_c * n_l * n_b * n_x, + max([n_c, n_l, n_b, n_x, n_y, n_z]), + ) + + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + subplot_indices = [[i, j] for i in range(nrow) for j in range(ncol)] + cur_subplot = 0 + colors_list = [] + + if total_panels == 1: + pl = pv.Plotter() if plot_method == "pv" else make_subplots(rows=1, cols=1, specs=[[{"type": "scatter3d"}]]) + else: + pl = ( + pv.Plotter(shape=(nrow, ncol)) + if plot_method == "pv" + else + make_subplots(rows=nrow, cols=ncol, specs=[[{"type": "scatter3d"} for _ in range(ncol)] for _ in range(nrow)]) + ) + + def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: str, cur_z: str) -> None: + """A helper function for plotting a specific basis/layer data + + Args: + cur_b: current basis + cur_l: current layer + """ + nonlocal background, cur_subplot, cur_l_smoothed, cmap, sym_c, labels, values, theme, color_key, color_key_cmap + + cur_l_smoothed = cur_l_smoothed if cur_l_smoothed is not None else cur_l + main_debug("coloring scatter of cur_c: %s" % str(cur_c)) + + main_debug("handling coordinates, cur_x: %s, cur_y: %s, cur_z: %s" % (cur_x, cur_y, cur_z)) + + points, cur_title = map_to_points( + adata, + axis_x=cur_x, + axis_y=cur_y, + axis_z=cur_z, + basis_key=basis_key, + cur_c=cur_c, + cur_b=cur_b, + cur_l_smoothed=cur_l_smoothed, + ) + + ( + _color, + _labels, + _values, + _theme_, + _cmap, + _color_key_cmap, + background, + font_color, + is_not_continuous, + cur_title, + ) = _get_color_parameters( + adata=adata, + color=cur_c, + layer=cur_l, + labels=labels, + values=values, + theme=theme, + cmap=cmap, + color_key_cmap=color_key_cmap, + background=background, + title=cur_title, + ) + + + if smooth and not is_not_continuous: + main_debug("smooth and not continuous") + knn = adata.obsp["moments_con"] + _values = ( + calc_1nd_moment(_values, knn)[0] + if smooth in [1, True] + else calc_1nd_moment(_values, knn**smooth)[0] + ) + + colors, color_type, _ = calculate_colors( + points.values, + labels=_labels, + values=_values, + cmap=_cmap, + color_key=color_key, + color_key_cmap=_color_key_cmap, + background=background, + sym_c=sym_c, + ) + + colors_list.append(colors) + + if plot_method == "pv": + if total_panels > 1: + pl.subplot(subplot_indices[cur_subplot][0], subplot_indices[cur_subplot][1]) + + pvdataset = pv.PolyData(points.values) + pvdataset.point_data["colors"] = np.stack(colors) + pl.add_points(pvdataset, scalars="colors", preference='point', rgb=True, cmap=_cmap, **kwargs) + + if color_type == "labels": + type_color_dict = {cell_type: cell_color for cell_type, cell_color in zip(_labels, colors)} + type_color_pair = [[k, v] for k, v in type_color_dict.items()] + pl.add_legend(labels=type_color_pair) + else: + pl.add_scalar_bar() # TODO: fix the bug that scalar bar only works in the first plot + + pl.add_text(cur_title) + pl.add_axes(xlabel=points.columns[0], ylabel=points.columns[1], zlabel=points.columns[2]) + elif plot_method == "plotly": + + pl.add_trace( + go.Scatter3d( + x=points.iloc[:, 0], + y=points.iloc[:, 1], + z=points.iloc[:, 2], + mode="markers", + marker=dict( + color=colors, + ), + text=_labels if color_type == "labels" else _values, + **kwargs, + ), + row=subplot_indices[cur_subplot][0] + 1, col=subplot_indices[cur_subplot][1] + 1, + ) + + pl.update_layout( + scene=dict( + xaxis_title=points.columns[0], + yaxis_title=points.columns[1], + zaxis_title=points.columns[2] + ), + ) + + cur_subplot += 1 + + for cur_b in basis: + for cur_l in layer: + main_debug("Plotting basis:%s, layer: %s" % (str(basis), str(layer))) + main_debug("colors: %s" % (str(color))) + + if cur_l in ["acceleration", "curvature", "divergence", "velocity_S", "velocity_T"]: + cur_l_smoothed = cur_l + cmap, sym_c = "bwr", True # TODO maybe use other divergent color map in the future + else: + if use_smoothed: + cur_l_smoothed = cur_l if cur_l.startswith("M_") | cur_l.startswith("velocity") else mapper[cur_l] + if cur_l.startswith("velocity"): + cmap, sym_c = "bwr", True + + if cur_l + "_" + cur_b in adata.obsm.keys(): + prefix = cur_l + "_" + elif ("X_" + cur_b) in adata.obsm.keys(): + prefix = "X_" + elif cur_b in adata.obsm.keys(): + # special case for spatial for compatibility with other packages + prefix = "" + else: + raise ValueError("Please check if basis=%s exists in adata.obsm" % basis) + + basis_key = prefix + cur_b + main_info("plotting with basis key=%s" % basis_key, indent_level=2) + + for cur_c in color: + for cur_x, cur_y, cur_z in zip(x, y, z): + _plot_basis_layer_pv(cur_b, cur_l, cur_c, cur_x, cur_y, cur_z) + + + return save_pyvista_plotter( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) if plot_method == "pv" else save_plotly_figure( + pl=pl, + colors_list=colors_list, + save_show_or_return=save_show_or_return, + save_kwargs=save_kwargs, + ) + + def scatters( adata: AnnData, basis: str = "umap", diff --git a/dynamo/vectorfield/topography.py b/dynamo/vectorfield/topography.py index e407ae307..7cc31bb10 100755 --- a/dynamo/vectorfield/topography.py +++ b/dynamo/vectorfield/topography.py @@ -882,8 +882,8 @@ def func(x): "Xss": Xss, "ftype": ftype, "confidence": confidence, - "NCx": {str(index): array for index, array in enumerate(NCx)}, - "NCy": {str(index): array for index, array in enumerate(NCy)}, + "NCx": {str(index): array for index, array in enumerate(NCx)} if NCx is not None else None, + "NCy": {str(index): array for index, array in enumerate(NCy)} if NCy is not None else None, "separatrices": None, "fp_ind": fp_ind, } @@ -897,8 +897,8 @@ def func(x): "Xss": Xss, "ftype": ftype, "confidence": confidence, - "NCx": {str(index): array for index, array in enumerate(NCx)}, - "NCy": {str(index): array for index, array in enumerate(NCy)}, + "NCx": {str(index): array for index, array in enumerate(NCx)} if NCx is not None else None, + "NCy": {str(index): array for index, array in enumerate(NCy)} if NCy is not None else None, "separatrices": None, "fp_ind": fp_ind, } From 8ce4d090644b3a398d1f58379f29500bb0579385 Mon Sep 17 00:00:00 2001 From: sichao Date: Tue, 30 Jan 2024 15:45:04 -0500 Subject: [PATCH 9/9] create _validate_parameters func --- dynamo/plot/scatters.py | 213 +++++++++++++++++++--------------------- 1 file changed, 101 insertions(+), 112 deletions(-) diff --git a/dynamo/plot/scatters.py b/dynamo/plot/scatters.py index e0a5e1dda..fd2902908 100755 --- a/dynamo/plot/scatters.py +++ b/dynamo/plot/scatters.py @@ -1436,60 +1436,21 @@ def scatters_interactive( else: raise NotImplementedError("Current plot method not supported.") - if background is None: - background = rcParams.get("figure.facecolor") - background = to_hex(background) if type(background) is tuple else background - # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) - - if type(x) in [int, str]: - x = [x] - if type(y) in [int, str]: - y = [y] - if type(z) in [int, str]: - z = [z] - - # make x, y, z lists of list, where each list corresponds to one coordinate set - if ( - type(x) in [anndata._core.views.ArrayView, np.ndarray] - and type(y) in [anndata._core.views.ArrayView, np.ndarray] - and type(z) in [anndata._core.views.ArrayView, np.ndarray] - and len(x) == adata.n_obs - and len(y) == adata.n_obs - and len(z) == adata.n_obs - ): - x, y, z = [x], [y], [z] - - elif hasattr(x, "__len__") and hasattr(y, "__len__") and hasattr(z, "__len__"): - x, y, z = list(x), list(y), list(z) - - assert len(x) == len(y) and len(x) == len(z), "bug: x, y, z does not have the same shape." - if use_smoothed: mapper = get_mapper() # check color, layer, basis -> convert to list - if type(color) is str: - color = [color] - if type(layer) is str: - layer = [layer] - if type(basis) is str: - basis = [basis] - - n_c, n_l, n_b, n_x, n_y, n_z = ( - 1 if color is None else len(color), - 1 if layer is None else len(layer), - 1 if basis is None else len(basis), - 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), - 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), - 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), - ) - - total_panels, ncols = ( - n_c * n_l * n_b * n_x, - max([n_c, n_l, n_b, n_x, n_y, n_z]), + basis, background, color, layer, x, y, z, nrow, ncol, total_panels = _validate_parameters( + adata=adata, + basis=basis, + background=background, + color=color, + layer=layer, + x=x, + y=y, + z=z, ) - nrow, ncol = int(np.ceil(total_panels / ncols)), ncols subplot_indices = [[i, j] for i in range(nrow) for j in range(ncol)] cur_subplot = 0 colors_list = [] @@ -1518,7 +1479,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: main_debug("handling coordinates, cur_x: %s, cur_y: %s, cur_z: %s" % (cur_x, cur_y, cur_z)) - points, cur_title = map_to_points( + points, cur_title = _map_to_points( adata, axis_x=cur_x, axis_y=cur_y, @@ -1527,6 +1488,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: cur_c=cur_c, cur_b=cur_b, cur_l_smoothed=cur_l_smoothed, + projection="3d", ) ( @@ -1648,6 +1610,7 @@ def _plot_basis_layer_pv(cur_b: str, cur_l: str, cur_c: str, cur_x: str, cur_y: main_info("plotting with basis key=%s" % basis_key, indent_level=2) for cur_c in color: + x, y, z = _standardize_dimensions(x, y, z, adata.n_obs, projection="3d") for cur_x, cur_y, cur_z in zip(x, y, z): _plot_basis_layer_pv(cur_b, cur_l, cur_c, cur_x, cur_y, cur_z) @@ -1922,40 +1885,6 @@ def scatters( if stack_colors: color_key = None - if all([is_gene_name(adata, i) for i in basis]): - if x[0] not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y[0] not in [ - "M_u", - "X_unspliced", - "M_n", - "X_new", - "unspliced", - "new", - ]: - if "M_t" in adata.layers.keys() and "M_n" in adata.layers.keys(): - x, y = ["M_t"], ["M_n"] - elif "X_total" in adata.layers.keys() and "X_new" in adata.layers.keys(): - x, y = ["X_total"], ["X_new"] - elif "M_s" in adata.layers.keys() and "M_u" in adata.layers.keys(): - x, y = ["M_s"], ["M_u"] - elif "X_spliced" in adata.layers.keys() and "X_unspliced" in adata.layers.keys(): - x, y = ["X_spliced"], ["X_unspliced"] - elif "spliced" in adata.layers.keys() and "unspliced" in adata.layers.keys(): - x, y = ["spliced"], ["unspliced"] - elif "total" in adata.layers.keys() and "new" in adata.layers.keys(): - x, y = ["total"], ["new"] - else: - raise ValueError( - "your adata object is corrupted. Please make sure it has at least one of the following " - "pair of layers:" - "'M_s', 'X_spliced', 'M_t', 'X_total', 'spliced', 'total' and " - "'M_u', 'X_unspliced', 'M_n', 'X_new', 'unspliced', 'new'. " - ) - - if background is None: - background = rcParams.get("figure.facecolor") - background = to_hex(background) if type(background) is tuple else background - # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) - if not (affine_transform_degree is None): affine_transform_A = gen_rotation_2d(affine_transform_degree) affine_transform_b = 0 @@ -1982,37 +1911,18 @@ def scatters( if figsize is None: figsize = plt.rcParams["figsize"] - if type(x) in [int, str]: - x = [x] - if type(y) in [int, str]: - y = [y] - if type(z) in [int, str]: - z = [z] - - if type(color) is str: - color = [color] - if type(layer) is str: - layer = [layer] - if type(basis) is str: - basis = [basis] - - n_c, n_l, n_b, n_x, n_y, n_z = ( - 1 if color is None else len(color), - 1 if layer is None else len(layer), - 1 if basis is None else len(basis), - 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), - # check whether it is an array - 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), - # check whether it is an array - 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), + basis, background, color, layer, x, y, z, nrow, ncol, total_panels = _validate_parameters( + adata=adata, + basis=basis, + background=background, + color=color, + layer=layer, + x=x, + y=y, + z=z, + ncols=ncols, ) - total_panels, ncols = ( - n_c * n_l * n_b * n_x, - min(max([n_c, n_l, n_b, n_x, n_y, n_z]), ncols), - ) - nrow, ncol = int(np.ceil(total_panels / ncols)), ncols - if total_panels > 1 and ax is None: figure = plt.figure( None, @@ -2543,6 +2453,85 @@ def scatters_single_input( return ax, color_out, font_color, stack_legend_handles +def _validate_parameters( + adata: AnnData, + basis: Union[str, List[str]], + background: Optional[str], + color: Union[str, List[str]], + layer: Union[str, List[str]], + x: Union[str, List[str], int, List[int], None], + y: Union[str, List[str], int, List[int], None], + z: Union[str, List[str], int, List[int], None], + ncols: int = 4, +) -> Tuple: + """A helper function to validate parameters for scatters function.""" + if all([is_gene_name(adata, i) for i in basis]): + if x[0] not in ["M_s", "X_spliced", "M_t", "X_total", "spliced", "total"] and y[0] not in [ + "M_u", + "X_unspliced", + "M_n", + "X_new", + "unspliced", + "new", + ]: + if "M_t" in adata.layers.keys() and "M_n" in adata.layers.keys(): + x, y = ["M_t"], ["M_n"] + elif "X_total" in adata.layers.keys() and "X_new" in adata.layers.keys(): + x, y = ["X_total"], ["X_new"] + elif "M_s" in adata.layers.keys() and "M_u" in adata.layers.keys(): + x, y = ["M_s"], ["M_u"] + elif "X_spliced" in adata.layers.keys() and "X_unspliced" in adata.layers.keys(): + x, y = ["X_spliced"], ["X_unspliced"] + elif "spliced" in adata.layers.keys() and "unspliced" in adata.layers.keys(): + x, y = ["spliced"], ["unspliced"] + elif "total" in adata.layers.keys() and "new" in adata.layers.keys(): + x, y = ["total"], ["new"] + else: + raise ValueError( + "your adata object is corrupted. Please make sure it has at least one of the following " + "pair of layers:" + "'M_s', 'X_spliced', 'M_t', 'X_total', 'spliced', 'total' and " + "'M_u', 'X_unspliced', 'M_n', 'X_new', 'unspliced', 'new'. " + ) + + if background is None: + background = rcParams.get("figure.facecolor") + background = to_hex(background) if type(background) is tuple else background + # if save_show_or_return != 'save': set_figure_params('dynamo', background=_background) + + if type(x) in [int, str]: + x = [x] + if type(y) in [int, str]: + y = [y] + if type(z) in [int, str]: + z = [z] + + if type(color) is str: + color = [color] + if type(layer) is str: + layer = [layer] + if type(basis) is str: + basis = [basis] + + n_c, n_l, n_b, n_x, n_y, n_z = ( + 1 if color is None else len(color), + 1 if layer is None else len(layer), + 1 if basis is None else len(basis), + 1 if x is None else 1 if type(x) in [anndata._core.views.ArrayView, np.ndarray] else len(x), + # check whether it is an array + 1 if y is None else 1 if type(y) in [anndata._core.views.ArrayView, np.ndarray] else len(y), + # check whether it is an array + 1 if z is None else 1 if type(z) in [anndata._core.views.ArrayView, np.ndarray] else len(z), + ) + + total_panels, ncols = ( + n_c * n_l * n_b * n_x, + min(max([n_c, n_l, n_b, n_x, n_y, n_z]), ncols), + ) + nrow, ncol = int(np.ceil(total_panels / ncols)), ncols + + return basis, background, color, layer, x, y, z, nrow, ncol, total_panels + def _get_basis_key( adata: AnnData, basis: str,