diff --git a/archetypes/datasets/make_archetypal_dataset.py b/archetypes/datasets/make_archetypal_dataset.py index 4aee894..fa6eedb 100644 --- a/archetypes/datasets/make_archetypal_dataset.py +++ b/archetypes/datasets/make_archetypal_dataset.py @@ -57,10 +57,17 @@ def make_archetypal_dataset( A = [np.zeros((s_i, a_i)) for s_i, a_i in zip(shape, n_archetypes)] for A_i, labels_i in zip(A, labels): + l_i_prev = -1 for i, l_i in enumerate(labels_i): - alpha_i = [alpha] * A_i.shape[1] - alpha_i[l_i] = 1 - A_i[i, :] = generator.dirichlet(alpha_i) + if l_i_prev != l_i: + alpha_i = [0] * A_i.shape[1] + alpha_i[l_i] = 1 + A_i[i, :] = alpha_i + l_i_prev = l_i + else: + alpha_i = [alpha] * A_i.shape[1] + alpha_i[l_i] = 1 + A_i[i, :] = generator.dirichlet(alpha_i) X = einsum(A, archetypes) diff --git a/archetypes/datasets/permutations.py b/archetypes/datasets/permutations.py index 202bcc3..b8f57b6 100644 --- a/archetypes/datasets/permutations.py +++ b/archetypes/datasets/permutations.py @@ -88,9 +88,13 @@ def sort_by_archetype_similarity(data, alphas): data, info = permute_dataset(data, perms) labels = [np.argmax(a, axis=1) for a in alphas] + scores = [np.max(a, axis=1) for a in alphas] labels = [labels[i][perms[i]] for i in range(data.ndim)] + scores = [scores[i][perms[i]] for i in range(data.ndim)] info["labels"] = labels + info["scores"] = scores + info["n_archetypes"] = [ai.shape[1] for ai in alphas] return data, info diff --git a/archetypes/visualization/bisimplex.py b/archetypes/visualization/bisimplex.py index dabd756..a6d20ad 100644 --- a/archetypes/visualization/bisimplex.py +++ b/archetypes/visualization/bisimplex.py @@ -2,19 +2,8 @@ import matplotlib.pyplot as plt import numpy as np -from archetypes.visualization import simplex - - -def _create_palette(saturation, value, n_colors, int_colors=3): - hue = np.linspace(0, 1, n_colors, endpoint=False) - hue = np.hstack([hue[i::int_colors] for i in range(int_colors)]) - saturation = np.full(n_colors, saturation) - value = np.full(n_colors, value) - # convert to RGB - c = mpl.colors.hsv_to_rgb(np.vstack([hue, saturation, value]).T) - # Create palette - palette = mpl.colors.ListedColormap(c) - return palette +from .simplex import simplex +from .utils import create_palette def bisimplex(alphas, archetypes, ax=None, **kwargs): @@ -44,7 +33,7 @@ def bisimplex(alphas, archetypes, ax=None, **kwargs): n_archetypes = archetypes.shape # Get the colors for the vertices of the polytopes - palette = _create_palette( + palette = create_palette( saturation=0.35, value=0.9, n_colors=n_archetypes[0] + n_archetypes[1], int_colors=1 ) @@ -85,7 +74,7 @@ def bisimplex(alphas, archetypes, ax=None, **kwargs): # Use grayscale palette from matplotlib between 0 and 1 # archetypes[:] = .6 - palette = mpl.colormaps["Grays"] + palette = mpl.colormaps["Greys"] for x_i, y_i, c_i, a_i in zip( xx.flatten(), yy.flatten(), archetypes.flatten(), archetypes_scaled.flatten() diff --git a/archetypes/visualization/heatmap.py b/archetypes/visualization/heatmap.py index f164980..3c3ffd3 100644 --- a/archetypes/visualization/heatmap.py +++ b/archetypes/visualization/heatmap.py @@ -1,9 +1,12 @@ import matplotlib.pyplot as plt import numpy as np +from matplotlib.colors import LinearSegmentedColormap, to_rgb from matplotlib.patches import Polygon +from .utils import create_palette -def heatmap(data, labels=None, ax=None, **kwargs): + +def heatmap(data, labels=None, n_archetypes=None, scores=None, ax=None, **kwargs): """Plot a heatmap of the data. If labels are provided, the heatmap is divided into cells. Parameters @@ -11,7 +14,13 @@ def heatmap(data, labels=None, ax=None, **kwargs): data: np.ndarray The data to plot. labels: list of np.ndarray or None - The labels to use to divide the heatmap into cells. If None, no labels are used. + The labels values to use for the plot. + If None, the labels values are computed from the labels. + n_archetypes: list of int or None + The number of archetypes for each dimension. + If None, the number of archetypes is computed from the labels. + scores: list of np.ndarray or None + The scores values to use for the plot. ax: matplotlib.pyplot.axes or None The axes to plot on. If None, a new figure and axes is created. kwargs: dict @@ -33,6 +42,8 @@ def heatmap(data, labels=None, ax=None, **kwargs): if "cmap" not in kwargs: kwargs["cmap"] = "Greys" + data_size = max(data.shape) + # Plot line if labels[i] != labels[i+1] if labels is not None: # check labels is a list of 2 arrays @@ -47,22 +58,87 @@ def heatmap(data, labels=None, ax=None, **kwargs): f"labels must be a list of 2 arrays, got {type(labels[0])} and {type(labels[1])}" ) - labels_h = np.concatenate([[-1], labels[1].flatten(), [-1]]) - labels_v = np.concatenate([[-1], labels[0].flatten(), [-1]]) + labels_h = np.concatenate([labels[1].flatten()]) + labels_v = np.concatenate([labels[0].flatten()]) - polygon_kwargs = {"color": "r", "lw": 1} + polygon_kwargs = {"color": "k", "lw": 1} for i in range(len(labels_h) - 1): if labels_h[i] != labels_h[i + 1]: - line = Polygon(np.array([[i, 0], [i, data.shape[0]]]) - 0.5, **polygon_kwargs) + line = Polygon( + np.array([[i + 1, 0], [i + 1, data.shape[0]]]) - 0.5, **polygon_kwargs + ) ax.add_patch(line) for i in range(len(labels_v) - 1): if labels_v[i] != labels_v[i + 1]: - line = Polygon(np.array([[0, i], [data.shape[1], i]]) - 0.5, **polygon_kwargs) + line = Polygon( + np.array([[0, i + 1], [data.shape[1], i + 1]]) - 0.5, **polygon_kwargs + ) ax.add_patch(line) - ax.matshow(data, rasterized=True, **kwargs) + if n_archetypes is None: + n_archetypes = [len(np.unique(labels[0])), len(np.unique(labels[1]))] + + # Add a rectangle to frame the data + rect = Polygon( + np.array( + [[0, 0], [data.shape[1], 0], [data.shape[1], data.shape[0]], [0, data.shape[0]]] + ) + - 0.5, + fill=False, + **polygon_kwargs, + ) + ax.add_patch(rect) + + palette = create_palette( + saturation=0.35, value=0.9, n_colors=n_archetypes[0] + n_archetypes[1], int_colors=1 + ) + colors = palette(np.arange(n_archetypes[0] + n_archetypes[1])) + colors_1 = colors[: n_archetypes[0]] + colors_2 = colors[n_archetypes[0] : n_archetypes[0] + n_archetypes[1]] + + # Plot archetypes + counts = [np.count_nonzero(labels[0] == i) for i in range(n_archetypes[0])] + counts = np.cumsum(counts) + counts = np.concatenate([[0], counts]) - 0.5 + + arch_factor = 0.05 * data_size + + if scores is None: + scores = [np.ones_like(labels[0]), np.ones_like(labels[1])] + + for c, (i0, i1) in zip(colors_1, zip(counts, counts[1:])): + c1 = np.array(to_rgb(c)) + c2 = np.array([1, 1, 1]) + + ax.imshow( + scores[0][int(i0 + 0.5) : int(i1 + 0.5)][::-1].reshape(-1, 1), + extent=[-arch_factor, -2 * arch_factor, i0, i1], + cmap=LinearSegmentedColormap.from_list("c", [c2, c1]), + interpolation="none", + vmax=1, + vmin=0, + ) + + counts = [np.count_nonzero(labels[1] == i) for i in range(n_archetypes[1])] + counts = np.cumsum(counts) + counts = np.concatenate([[0], counts]) - 0.5 + + for c, (i0, i1) in zip(colors_2, zip(counts, counts[1:])): + c1 = np.array(to_rgb(c)) + c2 = np.array([1, 1, 1]) + + ax.imshow( + scores[1][int(i0 + 0.5) : int(i1 + 0.5)].reshape(1, -1), + extent=[i0, i1, -arch_factor, -2 * arch_factor], + cmap=LinearSegmentedColormap.from_list("c", [c2, c1]), + interpolation="none", + vmax=1, + vmin=0, + ) + + ax.matshow(data, interpolation="none", **kwargs) # set aspect ratio to equal ax.set_aspect("equal") @@ -75,9 +151,8 @@ def heatmap(data, labels=None, ax=None, **kwargs): xlim = ax.get_xlim() ylim = ax.get_ylim() - exp_factor = 0.01 * max(np.abs(np.diff(xlim)), np.abs(np.diff(ylim))) - - ax.set_xlim(xlim[0] - exp_factor, xlim[1] + exp_factor) - ax.set_ylim(ylim[0] + exp_factor, ylim[1] - exp_factor) + lim_factor = 0.01 * data_size + ax.set_xlim(xlim[0] - lim_factor, xlim[1] + lim_factor) + ax.set_ylim(ylim[0] + lim_factor, ylim[1] - lim_factor) return ax diff --git a/archetypes/visualization/simplex.py b/archetypes/visualization/simplex.py index 4602690..12bfee9 100644 --- a/archetypes/visualization/simplex.py +++ b/archetypes/visualization/simplex.py @@ -88,7 +88,7 @@ def simplex( for p1, p2 in edges: x1, y1 = p1 x2, y2 = p2 - ax.plot([x1, x2], [y1, y2], "-", linewidth=0.75, color="lightgray", zorder=0) + ax.plot([x1, x2], [y1, y2], "-", linewidth=1, color="lightgray", zorder=0) # ax.plot(vertices[:, 0], vertices[:, 1], "o", color="black", alpha=1) if show_vertices: diff --git a/archetypes/visualization/utils.py b/archetypes/visualization/utils.py new file mode 100644 index 0000000..6aaf7e5 --- /dev/null +++ b/archetypes/visualization/utils.py @@ -0,0 +1,14 @@ +import matplotlib as mpl +import numpy as np + + +def create_palette(saturation, value, n_colors, int_colors=3): + hue = np.linspace(0, 1, n_colors, endpoint=False) + hue = np.hstack([hue[i::int_colors] for i in range(int_colors)]) + saturation = np.full(n_colors, saturation) + value = np.full(n_colors, value) + # convert to RGB + c = mpl.colors.hsv_to_rgb(np.vstack([hue, saturation, value]).T) + # Create palette + palette = mpl.colors.ListedColormap(c) + return palette