# plot_tools

> A set of convenience functions used for producing plots in `dabest`.

- order: 5

In [None]:
#| default_exp plot_tools

In [None]:
#| export
from __future__ import annotations

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev
# Regenerate the library module(s) from the notebook's `#| export` cells (nbdev).
# This cell is `#| hide`, so it does not appear in the rendered docs.
nbdev.nbdev_export()

In [None]:
#| export
import math
import warnings
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.axes as axes
import matplotlib.patches as mpatches
from collections import defaultdict
from typing import List, Tuple, Dict, Iterable, Union
from pandas.api.types import CategoricalDtype
from matplotlib.colors import ListedColormap

In [None]:
#| export
def halfviolin(v, half="right", fill_color="k", alpha=1, line_color="k", line_width=0):
    """
    Cut every body of a violin plot down to one half, in place, and restyle it.

    The vertices of each body's first path are clamped to the mean x
    ('left'/'right') or the mean y ('top'/'bottom') of that path. Any other
    value of `half` leaves the geometry untouched. Every body is then given
    the fill colour, alpha, edge colour and line width supplied.

    Parameters
    ----------
    v : dict-like
        Must provide ``v["bodies"]``, e.g. the dict returned by ``Axes.violinplot``.
    half : str
        Which half to keep: 'right', 'left', 'top' or 'bottom'.
    fill_color, alpha, line_color, line_width
        Styling applied to each body, in that order.
    """
    for body in v["bodies"]:
        verts = body.get_paths()[0].vertices
        x_centre = np.mean(verts[:, 0])
        y_centre = np.mean(verts[:, 1])

        if half in ("right", "left"):
            lower, upper = (x_centre, np.inf) if half == "right" else (-np.inf, x_centre)
            verts[:, 0] = np.clip(verts[:, 0], lower, upper)
        elif half in ("top", "bottom"):
            lower, upper = (y_centre, np.inf) if half == "top" else (-np.inf, y_centre)
            verts[:, 1] = np.clip(verts[:, 1], lower, upper)

        # Order matters: set_color also sets the edge colour, which is then overridden.
        body.set_color(fill_color)
        body.set_alpha(alpha)
        body.set_edgecolor(line_color)
        body.set_linewidth(line_width)


def get_swarm_spans(coll):
    """
    Return the x and y extents of the points held by a matplotlib Collection.

    Parameters
    ----------
    coll : matplotlib Collection
        Any object whose ``get_offsets()`` returns an (N, 2) array-like of
        (x, y) point offsets.

    Returns
    -------
    tuple or None
        ``(x_min, x_max, y_min, y_max)``. Returns None (and emits a warning)
        if the spans cannot be computed, e.g. when the collection has no points.

    Raises
    ------
    ValueError
        If `coll` is None.
    """
    if coll is None:
        raise ValueError("The collection `coll` parameter cannot be None")

    try:
        # The unpacking is inside the try: empty offsets may not transpose into
        # exactly two rows, which raises ValueError just like min()/max() on an
        # empty array does.
        x, y = np.array(coll.get_offsets()).T
        return x.min(), x.max(), y.min(), y.max()
    except ValueError as e:
        warnings.warn(f"Failed to calculate spans for the collection. Details: {e}")
        return None


def error_bar(
    data: pd.DataFrame,  # This DataFrame should be in 'long' format.
    x: str,  # x column to be plotted.
    y: str,  # y column to be plotted.
    type: str = "mean_sd",  # Choose from ['mean_sd', 'median_quartiles']. Plots the summary statistics for each group. If 'mean_sd', then the mean and standard deviation of each group is plotted as a gapped line. If 'median_quartiles', then the median and 25th and 75th percentiles of each group is plotted instead.
    offset: float = 0.2,  # Give a single float (that will be used as the x-offset of all gapped lines), or an iterable containing the list of x-offsets.
    ax=None,  # If a matplotlib Axes object is specified, the gapped lines will be plotted in order on this axes. If None, the current axes (plt.gca()) is used.
    line_color="black",  # The color of the gapped lines.
    gap_width_percent=1,  # The width of the gap in the gapped lines, as a percentage of the y-axis span.
    pos: list = None,  # The x positions of the error bars for the sankey_error_bar method. Defaults to [0, 1].
    method: str = "gapped_lines",  # The method to use for drawing the error bars. Options are: 'gapped_lines', 'proportional_error_bar', and 'sankey_error_bar'.
    **kwargs: dict,
):
    """
    Draw one vertical error bar per group in `data[x]`, with a gap at the central value.

    Groups are taken in order of first appearance in `data[x]`, or in category
    order when `data[x]` is categorical. For each group a vertical line runs
    from the low value to the high value. It is broken around the central value
    by a gap of `gap_width_percent` percent of the current y-axis span. When low,
    high and central value are all equal, the line is drawn without a gap.

    - ``type='mean_sd'``: the centre is the mean and the bar spans mean +/- spread.
      The spread is the sample standard deviation for ``method='gapped_lines'``.
      For the proportional/sankey methods it is ``sqrt(p * (1 - p) / n)``, with p
      the group mean (this formula assumes `y` holds 0/1 values).
    - ``type='median_quartiles'``: the centre is the median and the bar spans the
      25th to 75th percentiles.

    Bars are placed at x = group index + offset. For ``method='sankey_error_bar'``
    they are placed at x = ``pos[i]`` + offset instead.

    Extra keyword arguments are forwarded to ``matplotlib.lines.Line2D``.
    Defaults are filled in for ``clip_on`` (False), ``zorder`` (5) and the line
    width (``lw`` = 2.0, unless ``lw`` or ``linewidth`` is given). ``clip_on`` is
    forced to True when any mean +/- spread falls outside the current y-limits.

    This function combines the functionality of gapped_lines(),
    proportional_error_bar(), and sankey_error_bar().

    Raises
    ------
    ValueError
        If `gap_width_percent`, `method` or `type` is invalid, or if the number
        of colours in `line_color` or of offsets in `offset` does not match the
        number of groups.
    """
    if gap_width_percent < 0 or gap_width_percent > 100:
        raise ValueError("`gap_width_percent` must be between 0 and 100.")
    if method not in ("gapped_lines", "proportional_error_bar", "sankey_error_bar"):
        raise ValueError(
            "Invalid `method`. Must be one of 'gapped_lines', "
            "'proportional_error_bar', or 'sankey_error_bar'."
        )
    if type not in ("mean_sd", "median_quartiles"):
        raise ValueError("Only accepted values for type are ['mean_sd', 'median_quartiles']")

    if pos is None:
        pos = [0, 1]

    if ax is None:
        ax = plt.gca()
    ax_ylims = ax.get_ylim()
    ax_yspan = np.abs(ax_ylims[1] - ax_ylims[0])
    gap_width = ax_yspan * gap_width_percent / 100

    kwargs.setdefault("clip_on", False)
    kwargs.setdefault("zorder", 5)
    # Line2D rejects receiving both an alias and its full name, so only add
    # `lw` when the caller has specified neither spelling.
    if "lw" not in kwargs and "linewidth" not in kwargs:
        kwargs["lw"] = 2.0

    if isinstance(data[x].dtype, pd.CategoricalDtype):
        group_order = pd.unique(data[x]).categories
    else:
        group_order = pd.unique(data[x])

    grouped = data.groupby(x)[y]
    means = grouped.mean().reindex(index=group_order)

    if method in ("proportional_error_bar", "sankey_error_bar"):

        def _proportion_spread(values):
            # S * (n - S) / n**3 == p * (1 - p) / n with p = S / n.
            n = len(values)
            successes = np.sum(values)
            return np.sqrt((successes * (n - successes)) / (n * n * n))

        # Reindex so the spread lines up with `means` in plotting order.
        sd = grouped.apply(_proportion_spread).reindex(index=group_order)
    else:
        sd = grouped.std().reindex(index=group_order)

    lower_sd = means - sd
    upper_sd = means + sd

    if (lower_sd < ax_ylims[0]).any() or (upper_sd > ax_ylims[1]).any():
        kwargs["clip_on"] = True

    if type == "mean_sd":
        central_measures, lows, highs = means, lower_sd, upper_sd
    else:
        central_measures = grouped.median().reindex(index=group_order)
        quantiles = grouped.quantile([0.25, 0.75]).unstack().reindex(index=group_order)
        lows, highs = quantiles[0.25], quantiles[0.75]

    n_groups = len(central_measures)

    if isinstance(line_color, str):
        custom_palette = np.repeat(line_color, n_groups)
    else:
        if len(line_color) != n_groups:
            err1 = "{} groups are being plotted, but ".format(n_groups)
            err2 = "{} colors(s) were supplied in `line_color`.".format(len(line_color))
            raise ValueError(err1 + err2)
        custom_palette = line_color

    try:
        len_offset = len(offset)
    except TypeError:
        offset = np.repeat(offset, n_groups)
        len_offset = len(offset)

    if len_offset != n_groups:
        err1 = "{} groups are being plotted, but ".format(n_groups)
        err2 = "{} offset(s) were supplied in `offset`.".format(len_offset)
        raise ValueError(err1 + err2)

    for i, central_measure in enumerate(central_measures):
        kwargs["color"] = custom_palette[i]

        base = pos[i] if method == "sankey_error_bar" else i
        _xpos = base + offset[i]

        # Positional access: the index holds group labels, which may themselves
        # be integers and would then be looked up by label with `[]`.
        low = lows.iloc[i]
        high = highs.iloc[i]

        if low == high == central_measure:
            # No spread at all: skip the gap so the bar does not vanish.
            lower_end, upper_start = central_measure, central_measure
        else:
            lower_end = central_measure - gap_width
            upper_start = central_measure + gap_width

        ax.add_line(mlines.Line2D([_xpos, _xpos], [low, lower_end], **kwargs))
        ax.add_line(mlines.Line2D([_xpos, _xpos], [upper_start, high], **kwargs))


def check_data_matches_labels(
    labels,  # list of input labels
    data,  # Pandas Series of input data
    side: str,  # 'left' or 'right' on the sankey diagram
):
    """
    Check that the labels given for one side of a sankey diagram match the data.

    `labels` and the unique values of `data` are compared as sets. Any
    iterable is accepted for `labels` and for `data` (list, tuple, ndarray or
    pandas Series). Nothing is checked when `labels` is empty.

    Raises
    ------
    Exception
        If the two sets differ. The message lists the labels and the data
        values (each only when there are at most 20 of them), sorted and
        converted to strings.
    """
    if len(labels) == 0:
        return

    label_set = set(labels)
    data_set = set(data.unique()) if isinstance(data, pd.Series) else set(data)

    if label_set != data_set:
        msg = "\n"
        if len(label_set) <= 20:
            msg += "Labels: " + ",".join(sorted(map(str, label_set))) + "\n"
        if len(data_set) <= 20:
            msg += "Data: " + ",".join(sorted(map(str, data_set)))
        raise Exception(f"{side} labels and data do not match.{msg}")


def normalize_dict(nested_dict, target):
    """
    Rescale the inner values of a nested dict so that each inner key sums to its target.

    For every inner key k (a key of one of the sub-dictionaries), the values
    ``nested_dict[outer][k]`` are summed across all sub-dictionaries. Each
    value is then scaled by ``target[k]["right"] / total``. After scaling,
    the values for k add up to ``target[k]["right"]`` while keeping their
    relative proportions.

    Parameters:
    nested_dict (dict of dict): Outer key -> {inner key: value}. Values that are
                                not dicts are left untouched.
    target (dict): Inner key -> dict with a 'right' entry holding the target total.
                   Every inner key of nested_dict must be present.

    Returns:
    dict: nested_dict itself, modified in place.

    Note:
    - If the total for an inner key is zero, its values are set to 0.
    - Totals are taken over every inner key, not only keys that also appear as
      outer keys. Inner keys without an outer counterpart are therefore shared
      proportionally rather than each receiving the full target.
    """
    sub_dicts = [sub for sub in nested_dict.values() if isinstance(sub, dict)]

    inner_keys = {key for sub in sub_dicts for key in sub}
    totals = {
        key: np.sum([sub[key] for sub in sub_dicts if key in sub])
        for key in inner_keys
    }

    for sub in sub_dicts:
        for key in sub:
            total = totals[key]
            if total != 0:
                sub[key] = sub[key] * target[key]["right"] / total
            else:
                # Every entry for this key is zero; avoid dividing by zero.
                sub[key] = 0
    return nested_dict


def width_determine(labels, data, pos="left"):
    """
    Stack `labels` vertically and return the bottom/top extent of each one.

    A label's height is the sum of column ``pos + "Weight"`` over the rows of
    `data` whose column `pos` equals that label. No normalisation is done here;
    presumably the weights are already scaled so they total 1 (TODO confirm
    with callers). With several labels, gaps are carved out of the heights:

    - the first and last labels lose 0.01;
    - middle labels lose 0.02 and start 0.02 above the previous label's top;
    - the first label starts at 0 and the last label ends at 1.

    A single label spans exactly [0, 1].

    Parameters:
    labels (list): Labels to position, bottom to top.
    data (DataFrame): Must contain columns `pos` and ``pos + "Weight"``.
    pos (str, optional): Column name / side, e.g. 'left' or 'right'. Defaults to 'left'.

    Returns:
    defaultdict: label -> {pos: height, "bottom": ..., "top": ...}.

    Raises:
    ValueError: If `labels` or `data` is None.
    """
    if labels is None:
        raise ValueError("The `labels` parameter cannot be None")
    if data is None:
        raise ValueError("The `data` parameter cannot be None")

    weight_column = pos + "Weight"
    label_count = len(labels)
    extents = defaultdict()

    for index, label in enumerate(labels):
        height = data[data[pos] == label][weight_column].sum()

        if label_count == 1:
            entry = {pos: height, "bottom": 0, "top": 1}
        elif index == 0:
            height -= 0.01
            entry = {pos: height, "bottom": 0, "top": height}
        elif index == label_count - 1:
            height -= 0.01
            entry = {pos: height, "bottom": 1 - height, "top": 1}
        else:
            height -= 0.02
            lower = extents[labels[index - 1]]["top"] + 0.02
            entry = {pos: height, "bottom": lower, "top": lower + height}

        extents[label] = entry
    return extents


def _sankey_label_extents(df, labels, side):
    """
    Stack `labels` vertically within [0, 1] for one side ('left'/'right') of a sankey diagram.

    Each label's height is its share of the total of column ``f"{side}_weight"``.
    Gaps between neighbouring labels are taken out of the heights: 0.01 from the
    outermost labels and 0.02 from middle labels. A single label spans [0, 1]
    with height 1.

    Returns a defaultdict: label -> {side: height, "bottom": ..., "top": ...}.
    """
    weight_col = f"{side}_weight"
    total_weight = df[weight_col].sum()
    n_labels = len(labels)
    extents = defaultdict()
    for i, label in enumerate(labels):
        if n_labels == 1:
            extents[label] = {side: 1, "bottom": 0, "top": 1}
            continue
        height = df[df[side] == label][weight_col].sum() / total_weight
        if i == 0:
            height -= 0.01
            bottom, top = 0, height
        elif i == n_labels - 1:
            height -= 0.01
            bottom, top = 1 - height, 1
        else:
            height -= 0.02
            bottom = extents[labels[i - 1]]["top"] + 0.02
            top = bottom + height
        extents[label] = {side: height, "bottom": bottom, "top": top}
    return extents


def single_sankey(
    left: np.array,  # data on the left of the diagram
    right: np.array,  # data on the right of the diagram, len(left) == len(right)
    xpos: float = 0,  # the starting point on the x-axis
    left_weight: np.array = None,  # weights for the left labels, if None, all weights are 1
    right_weight: np.array = None,  # weights for the right labels, if None, all weights are corresponding left_weight
    colorDict: dict = None,  # input format: {'label': 'color'}
    left_labels: list = None,  # labels for the left side of the diagram. The diagram will be sorted by these labels.
    right_labels: list = None,  # labels for the right side of the diagram. The diagram will be sorted by these labels.
    ax=None,  # matplotlib axes to be drawn on
    flow: bool = True,  # if True, draw the sankey in a flow, else draw 1 vs 1 Sankey diagram for each group comparison
    sankey: bool = True,  # if True, draw the sankey diagram, else draw barplot
    width=0.5,  # horizontal distance between the left and right label bars
    alpha=0.65,  # transparency of the strips
    bar_width=0.2,  # width of each label bar, as a fraction of `width`
    error_bar_on: bool = True,  # if True, draw error bar for each group comparison
    strip_on: bool = True,  # if True, draw strip for each group comparison
    one_sankey: bool = False,  # if True, only draw one sankey diagram
    right_color: bool = False,  # if True, each strip is colored by its right label; otherwise by its left label
    align: str = "center",  # if 'center', the diagram will be centered on each xtick,  if 'edge', the diagram will be aligned with the left edge of each xtick
):
    """
    Make a single Sankey diagram showing proportion flow from left to right
    Original code from: https://github.com/anazalea/pySankey
    Changes are added to normalize each diagram's height to be 1

    Label bars on each side are stacked in [0, 1] in the order of
    `left_labels` / `right_labels`. By default these are the unique values in
    reverse sorted order. Right-hand bars are drawn only when
    ``(not flow and sankey) or one_sankey``. Strips join each (left, right)
    pair that occurs in the data, when both `sankey` and `strip_on` are true.
    Error bars are drawn when both `error_bar_on` and `strip_on` are true.
    The caller's `left` / `right` objects are not modified.

    Raises
    ------
    Exception
        If `left` or `right` contain nulls, or supplied labels do not match the data.
    ValueError
        If `colorDict` lacks a label, or `align` is not 'center' or 'edge'.
    TypeError
        If `xpos` and `width` cannot be combined when ``align='center'``.
    """
    if ax is None:
        ax = plt.gca()

    if left_labels is None:
        left_labels = []
    if right_labels is None:
        right_labels = []

    # Default weights: 1 on the left; the right side falls back to the left weights.
    if left_weight is None or len(left_weight) == 0:
        left_weight = np.ones(len(left))
    if right_weight is None or len(right_weight) == 0:
        right_weight = left_weight

    # Re-index copies (not the caller's Series) so the columns line up by position.
    if isinstance(left, pd.Series):
        left = left.reset_index(drop=True)
    if isinstance(right, pd.Series):
        right = right.reset_index(drop=True)
    dataFrame = pd.DataFrame(
        {
            "left": left,
            "right": right,
            "left_weight": left_weight,
            "right_weight": right_weight,
        },
        index=range(len(left)),
    )

    if dataFrame[["left", "right"]].isnull().any(axis=None):
        raise Exception("Sankey graph does not support null values.")

    # Identify all labels that appear 'left' or 'right'
    allLabels = pd.Series(
        np.sort(np.r_[dataFrame.left.unique(), dataFrame.right.unique()])[::-1]
    ).unique()

    if len(left_labels) == 0:
        left_labels = pd.Series(np.sort(dataFrame.left.unique())[::-1]).unique()
    else:
        check_data_matches_labels(left_labels, dataFrame["left"], "left")

    if len(right_labels) == 0:
        right_labels = pd.Series(np.sort(dataFrame.right.unique())[::-1]).unique()
    else:
        check_data_matches_labels(right_labels, dataFrame["right"], "right")

    # If no colorDict given, make one from the "hls" palette; label 0 is always grey.
    if colorDict is None:
        colorDict = {}
        colorPalette = sns.color_palette("hls", len(allLabels))
        for i, label in enumerate(allLabels):
            colorDict[label] = colorPalette[i]
        colorDict.update({0: "grey"})
    else:
        missing = [label for label in allLabels if label not in colorDict.keys()]
        if missing:
            msg = "The palette parameter is missing values for the following labels : "
            msg += "{}".format(", ".join(map(str, missing)))
            raise ValueError(msg)

    if align not in ("center", "edge"):
        err = "{} assigned for `align` is not valid.".format(align)
        raise ValueError(err)
    if align == "center":
        try:
            leftpos = xpos - width / 2
        except TypeError as e:
            # Python scalars have no `.dtype`; fall back to the type name.
            xpos_dtype = getattr(xpos, "dtype", type(xpos).__name__)
            width_dtype = getattr(width, "dtype", type(width).__name__)
            raise TypeError(
                f"the dtypes of parameters x ({xpos_dtype}) "
                f"and width ({width_dtype}) "
                f"are incompatible"
            ) from e
    else:
        leftpos = xpos

    # Combine left and right arrays to have a pandas.DataFrame in the 'long' format
    left_series = pd.Series(left, name="values").to_frame().assign(groups="left")
    right_series = pd.Series(right, name="values").to_frame().assign(groups="right")
    concatenated_df = pd.concat([left_series, right_series], ignore_index=True)

    # Vertical extents of each label bar; each side's total height is 1.
    leftWidths_norm = _sankey_label_extents(dataFrame, left_labels, "left")
    rightWidths_norm = _sankey_label_extents(dataFrame, right_labels, "right")

    # Total width of the graph
    xMax = width

    # Plot vertical bars for each label
    for left_label in left_labels:
        ax.fill_between(
            [leftpos + (-(bar_width) * xMax * 0.5), leftpos + (bar_width * xMax * 0.5)],
            2 * [leftWidths_norm[left_label]["bottom"]],
            2 * [leftWidths_norm[left_label]["top"]],
            color=colorDict[left_label],
            alpha=0.99,
        )
    if (not flow and sankey) or one_sankey:
        for right_label in right_labels:
            ax.fill_between(
                [
                    xMax + leftpos + (-bar_width * xMax * 0.5),
                    leftpos + xMax + (bar_width * xMax * 0.5),
                ],
                2 * [rightWidths_norm[right_label]["bottom"]],
                2 * [rightWidths_norm[right_label]["top"]],
                color=colorDict[right_label],
                alpha=0.99,
            )

    # Plot error bars
    if error_bar_on and strip_on:
        error_bar(
            concatenated_df,
            x="groups",
            y="values",
            ax=ax,
            offset=0,
            gap_width_percent=2,
            method="sankey_error_bar",
            pos=[leftpos, leftpos + xMax],
        )

    # Heights of the individual strips. Left ends are scaled to the left bar
    # heights; right ends are rescaled by normalize_dict against the right bars.
    ns_l_norm = defaultdict()
    ns_r = defaultdict()
    for left_label in left_labels:
        leftDict = {}
        rightDict = {}
        for right_label in right_labels:
            pair_rows = dataFrame[
                (dataFrame.left == left_label) & (dataFrame.right == right_label)
            ]
            leftDict[right_label] = pair_rows.left_weight.sum()
            rightDict[right_label] = pair_rows.right_weight.sum()
        factorleft = leftWidths_norm[left_label]["left"] / sum(leftDict.values())
        ns_l_norm[left_label] = {k: v * factorleft for k, v in leftDict.items()}
        ns_r[left_label] = rightDict

    ns_r_norm = normalize_dict(ns_r, rightWidths_norm)

    # Plot strips
    if sankey and strip_on:
        for left_label, right_label in itertools.product(left_labels, right_labels):
            labelColor = right_label if right_color else left_label

            if len(dataFrame[(dataFrame.left == left_label) &
                        (dataFrame.right == right_label)]) > 0:
                # Create array of y values for each strip, half at left value,
                # half at right, convolve
                ys_d = np.array(
                    50 * [leftWidths_norm[left_label]["bottom"]]
                    + 50 * [rightWidths_norm[right_label]["bottom"]]
                )
                ys_d = np.convolve(ys_d, 0.05 * np.ones(20), mode="valid")
                ys_d = np.convolve(ys_d, 0.05 * np.ones(20), mode="valid")
                ys_u = np.array(
                    50 * [leftWidths_norm[left_label]["bottom"] + ns_l_norm[left_label][right_label]]
                    + 50 * [rightWidths_norm[right_label]["bottom"] + ns_r_norm[left_label][right_label]]
                )
                ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode="valid")
                ys_u = np.convolve(ys_u, 0.05 * np.ones(20), mode="valid")

                # Update bottom edges at each label so next strip starts at the right place
                leftWidths_norm[left_label]["bottom"] += ns_l_norm[left_label][right_label]
                rightWidths_norm[right_label]["bottom"] += ns_r_norm[left_label][right_label]
                ax.fill_between(
                    np.linspace(
                        leftpos + (bar_width * xMax * 0.5),
                        leftpos + xMax - (bar_width * xMax * 0.5),
                        len(ys_d),
                    ),
                    ys_d,
                    ys_u,
                    alpha=alpha,
                    color=colorDict[labelColor],
                    edgecolor="none",
                )


def sankeydiag(
    data: pd.DataFrame,
    xvar: str,  # x column to be plotted.
    yvar: str,  # y column to be plotted.
    temp_all_plot_groups: list,
    idx: list,
    temp_idx: list,
    left_labels: list = None,  # labels for the left side of the diagram. The diagram will be sorted by these labels.
    right_labels: list = None,  # labels for the right side of the diagram. The diagram will be sorted by these labels.
    palette: str | dict = None,
    ax=None,  # matplotlib axes to be drawn on
    flow: bool = True,  # if True, draw the sankey in a flow, else draw 1 vs 1 Sankey diagram for each group comparison
    sankey: bool = True,  # if True, draw the sankey diagram, else draw barplot
    one_sankey: bool = False,  # determined by the driver function on plotter.py, if True, draw the sankey diagram across the whole raw data axes
    width: float = 0.4,  # the width of each sankey diagram
    right_color: bool = False,  # if True, each strip is colored by its right label; otherwise by its left label
    align: str = "center",  # the alignment of each sankey diagram, can be 'center' or 'edge'
    alpha: float = 0.65,  # the transparency of each strip
    **kwargs,
):
    """
    Draw several sankey diagrams (or bar plots) side by side on one axes from long-format data.

    Each diagram compares the values of column `yvar` for two groups of
    column `xvar`. The left group is on the left of the diagram and the right
    group on the right.

    - With ``flow=True`` the pairs are consecutive groups within each tuple of
      `idx`, wrapping back to the first group.
    - Otherwise the pairs come from `temp_idx`.

    When `temp_all_plot_groups` has exactly two groups, a single diagram spans
    the axes (`one_sankey`).

    `palette` may be a dict mapping `yvar` values to colours, or a seaborn
    palette name. `**kwargs` may provide ``bar_width``; it defaults to 0.2, the
    same default as single_sankey.

    Returns
    -------
    tuple of list
        The `left_idx` and `right_idx` group lists that were drawn.

    Raises
    ------
    ValueError
        If a group is missing from `xvar`, the index lists have different
        lengths, or the palette keys are not values of `yvar`.
    """
    # Only `bar_width` can arrive via **kwargs: the other options are named
    # parameters. Always bind it, because both drawing branches below read it.
    bar_width = kwargs.get("bar_width", 0.2)

    if ax is None:
        ax = plt.gca()

    left_idx = []
    right_idx = []
    # Design for Sankey Flow Diagram
    sankey_idx = (
        [
            (control, test)
            for i in idx
            for control, test in zip(i[:], (i[1:] + (i[0],)))
        ]
        if flow
        else temp_idx
    )
    for i in sankey_idx:
        left_idx.append(i[0])
        right_idx.append(i[1])

    if len(temp_all_plot_groups) == 2:
        one_sankey = True
        # Remove the last pair from both lists (in flow mode, the wrap-around
        # pair back to the first group).
        left_idx.pop()
        right_idx.pop()

    allLabels = pd.Series(np.sort(data[yvar].unique())[::-1]).unique()

    # Check if all the elements in left_idx and right_idx are in xvar column
    unique_xvar = data[xvar].unique()
    if not all(elem in unique_xvar for elem in left_idx):
        raise ValueError(f"{left_idx} not found in {xvar} column")
    if not all(elem in unique_xvar for elem in right_idx):
        raise ValueError(f"{right_idx} not found in {xvar} column")

    xpos = 0

    # For baseline comparison, broadcast left_idx to the same length as right_idx
    # so that the left of sankey diagram will be the same
    # For sequential comparison, left_idx and right_idx can have anything different
    # but should have the same length
    if len(left_idx) == 1:
        broadcasted_left = np.broadcast_to(left_idx, len(right_idx))
    elif len(left_idx) != len(right_idx):
        raise ValueError(f"left_idx and right_idx should have the same length")
    else:
        broadcasted_left = left_idx

    if isinstance(palette, dict):
        if not all(key in allLabels for key in palette.keys()):
            raise ValueError(f"keys in palette should be in {yvar} column")
        plot_palette = palette
    elif isinstance(palette, str):
        plot_palette = {}
        colorPalette = sns.color_palette(palette, len(allLabels))
        for i, label in enumerate(allLabels):
            plot_palette[label] = colorPalette[i]
    else:
        plot_palette = None

    # Skip strips for a right group that has already appeared earlier as a
    # left group (repeated measures).
    strip_on = [
        int(right not in broadcasted_left[:i]) for i, right in enumerate(right_idx)
    ]

    draw_idx = list(zip(broadcasted_left, right_idx))
    for i, (left, right) in enumerate(draw_idx):
        if not one_sankey:
            if flow:
                width = 1
                align = "edge"
                sankey = (
                    False if i == len(draw_idx) - 1 else sankey
                )  # Remove last strip in flow
            error_bar_on = (
                False if i == len(draw_idx) - 1 and flow else True
            )  # Remove last error_bar in flow
            bar_width = 0.4 if sankey == False and flow == False else bar_width
            single_sankey(
                data[data[xvar] == left][yvar],
                data[data[xvar] == right][yvar],
                xpos=xpos,
                ax=ax,
                colorDict=plot_palette,
                width=width,
                left_labels=left_labels,
                right_labels=right_labels,
                strip_on=strip_on[i],
                right_color=right_color,
                bar_width=bar_width,
                sankey=sankey,
                error_bar_on=error_bar_on,
                flow=flow,
                align=align,
                alpha=alpha,
            )
            xpos += 1
        else:
            xpos = 0
            width = 1
            if not sankey:
                bar_width = 0.5
            single_sankey(
                data[data[xvar] == left][yvar],
                data[data[xvar] == right][yvar],
                xpos=xpos,
                ax=ax,
                colorDict=plot_palette,
                width=width,
                left_labels=left_labels,
                right_labels=right_labels,
                right_color=right_color,
                bar_width=bar_width,
                sankey=sankey,
                one_sankey=one_sankey,
                flow=False,
                align="edge",
                alpha=alpha,
            )

    # Now only draw vs xticks for two-column sankey diagram
    if not one_sankey or (sankey and not flow):
        sankey_ticks = (
            [f"{left}" for left in broadcasted_left]
            if flow
            else [
                f"{left}\n v.s.\n{right}"
                for left, right in zip(broadcasted_left, right_idx)
            ]
        )
        ax.get_xaxis().set_ticks(np.arange(len(right_idx)))
        ax.get_xaxis().set_ticklabels(sankey_ticks)
    else:
        sankey_ticks = [broadcasted_left[0], right_idx[0]]
        ax.set_xticks([0, 1])
        ax.set_xticklabels(sankey_ticks)

    return left_idx, right_idx

def summary_bars_plotter(summary_bars: list, results: object, ax_to_plot: object,
                         float_contrast: bool, summary_bars_kwargs: dict, ci_type: str,
                         ticks_to_plot: list, color_col: str, swarm_colors: list,
                         proportional: bool, is_paired: bool):
    """
    Add summary bars to the contrast plot.

    For each selected contrast a horizontal band covering its confidence
    interval is drawn behind the other artists (``zorder=-2``).

    Parameters
    ----------
    summary_bars : list
        List of indices (ints) of the contrast objects to plot summary bars
        for. Each index must satisfy ``0 <= index < len(results)``. An empty
        list draws nothing.
    results : object (Dataframe)
        Dataframe of contrast object comparisons. Its ``bca_low``/``bca_high``
        or ``pct_low``/``pct_high`` columns are read, depending on ``ci_type``.
    ax_to_plot : object
        Matplotlib axis object to plot on.
    float_contrast : bool
        Whether the DABEST plot uses Gardner-Altman (True) or Cummings (False).
    summary_bars_kwargs : dict
        Keyword arguments for the summary bars. An optional ``'color'`` entry
        overrides the bar colour. Every other entry is forwarded to
        ``matplotlib.patches.Rectangle``. The caller's dict is not modified.
    ci_type : str
        Type of confidence interval to plot. ``"bca"`` selects the BCa bounds;
        any other value selects the percentile bounds.
    ticks_to_plot : list
        x-tick position of each contrast, in the same order as ``results``.
    color_col : str
        Column name of the color column.
    swarm_colors : list
        List of colors used in the plot, indexed by tick position.
    proportional : bool
        Whether the data is proportional.
    is_paired : bool
        Whether the data is paired (truthy when paired).

    Raises
    ------
    TypeError
        If ``summary_bars`` is not a list of ints.
    ValueError
        If an index is out of range, or if ``float_contrast`` is True.
    """
    # Begin checks
    if not isinstance(summary_bars, list):
        raise TypeError("summary_bars must be a list of indices (ints).")
    if not all(isinstance(i, int) for i in summary_bars):
        raise TypeError("summary_bars must be a list of indices (ints).")
    # Negative indices are rejected too: results rows are looked up by label,
    # so -1 would not wrap around to the last contrast.
    out_of_range = [i for i in summary_bars if i < 0 or i >= len(results)]
    if out_of_range:
        raise ValueError(
            "Index {} chosen is out of range for the contrast objects.".format(out_of_range)
        )
    if float_contrast:
        raise ValueError("summary_bars cannot be used with Gardner-Altman plots.")
    # End checks

    if not summary_bars:
        return

    # Work on a copy so the caller's kwargs keep their 'color' entry.
    bar_kwargs = dict(summary_bars_kwargs)
    user_color = bar_kwargs.pop('color', None)

    summary_xmin, summary_xmax = ax_to_plot.get_xlim()
    for summary_index in summary_bars:
        if ci_type == "bca":
            summary_ci_low = results.bca_low[summary_index]
            summary_ci_high = results.bca_high[summary_index]
        else:
            summary_ci_low = results.pct_low[summary_index]
            summary_ci_high = results.pct_high[summary_index]

        # The colour depends on the contrast's tick position, which can be
        # larger than its row index (e.g. ticks [1, 3] for two contrasts).
        # The old list sized by max(summary_bars) could therefore overflow.
        tick = ticks_to_plot[summary_index]
        if user_color is not None:
            summary_color = user_color
        elif color_col is not None or is_paired:
            # Same truthiness as the former
            # `color_col is not None or (proportional and is_paired) or is_paired`.
            summary_color = 'black'
        else:
            summary_color = swarm_colors[tick]

        # Width is measured from the left x-limit; presumably sized so the band
        # covers the whole visible x-range.
        ax_to_plot.add_patch(mpatches.Rectangle(
            (summary_xmin, summary_ci_low), summary_xmax + 1,
            summary_ci_high - summary_ci_low,
            zorder=-2, color=summary_color, **bar_kwargs))


def contrast_bars_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object,
                          ticks_to_plot: list, contrast_bars_kwargs: dict, color_col: str,
                          swarm_colors: list, show_mini_meta: bool, mini_meta_delta: object,
                          show_delta2: bool, delta_delta: object, proportional: bool, is_paired: bool):
    """
    Add contrast bars to the contrast plot.

    Each contrast gets a 0.5-wide bar centred on its tick, rising from 0 to
    its effect size (``results.difference``), drawn at ``zorder=-1``. When
    requested, a black bar for the mini-meta weighted delta or the
    delta-delta is added two units right of the last swarm-plot tick.

    Parameters
    ----------
    results : object (Dataframe)
        Dataframe of contrast object comparisons; its ``difference`` column is read.
    ax_to_plot : object
        Matplotlib axis object to plot on.
    swarm_plot_ax : object (ax)
        Matplotlib axis object of the swarm plot. Only used to place the
        mini-meta / delta-delta bar.
    ticks_to_plot : list
        x-tick position of each contrast, in the same order as ``results``.
    contrast_bars_kwargs : dict
        Keyword arguments for the contrast bars. An optional ``'color'`` entry
        overrides the bar colour. All other entries are forwarded to
        ``matplotlib.patches.Rectangle``. The caller's dict is not modified.
    color_col : str
        Column name of the color column.
    swarm_colors : list
        List of colors used in the plot, indexed by tick position.
    show_mini_meta : bool
        Whether to show the mini meta-analysis.
    mini_meta_delta : object
        Mini meta-analysis object (its ``difference`` is read).
    show_delta2 : bool
        Whether to show the delta-delta.
    delta_delta : object
        delta-delta object (its ``difference`` is read).
    proportional : bool
        Whether the data is proportional.
    is_paired : bool
        Whether the data is paired (truthy when paired).
    """
    # Work on a copy so the caller's kwargs keep their 'color' entry.
    bar_kwargs = dict(contrast_bars_kwargs)
    user_color = bar_kwargs.pop('color', None)

    for j, tick in enumerate(ticks_to_plot):
        if user_color is not None:
            bar_color = user_color
        elif color_col is not None or is_paired:
            # Same truthiness as the former
            # `color_col is not None or (proportional and is_paired) or is_paired`.
            bar_color = 'black'
        else:
            bar_color = swarm_colors[tick]
        ax_to_plot.add_patch(mpatches.Rectangle(
            (tick - 0.25, 0), 0.5, results.difference[j],
            zorder=-1, color=bar_color, **bar_kwargs))

    if show_mini_meta:
        ax_to_plot.add_patch(mpatches.Rectangle(
            (max(swarm_plot_ax.get_xticks()) + 2 - 0.25, 0), 0.5, mini_meta_delta.difference,
            zorder=-1, color='black', **bar_kwargs))

    if show_delta2:
        ax_to_plot.add_patch(mpatches.Rectangle(
            (max(swarm_plot_ax.get_xticks()) + 2 - 0.25, 0), 0.5, delta_delta.difference,
            zorder=-1, color='black', **bar_kwargs))

def swarm_bars_plotter(plot_data: object, xvar: str, yvar: str, ax: object,
                       swarm_bars_kwargs: dict, color_col: str, swarm_colors: list, is_paired: bool):
    """
    Add bars to the raw data plot.

    One 0.5-wide bar is drawn per group, centred at x = 0, 1, 2, ... in group
    order. Each bar rises from 0 to the group's mean ``yvar``, at ``zorder=-1``.

    Parameters
    ----------
    plot_data : object (Dataframe)
        Dataframe of the plot data.
    xvar : str
        Column name of the x variable (group labels).
    yvar : str
        Column name of the y variable.
    ax : object
        Matplotlib axis object to plot on.
    swarm_bars_kwargs : dict
        Keyword arguments for the swarm bars. An optional ``'color'`` entry
        overrides the bar colour. All other entries are forwarded to
        ``matplotlib.patches.Rectangle``. The caller's dict is not modified.
    color_col : str
        Column name of the color column.
    swarm_colors : list
        List of colors used in the plot, one per group in plotting order.
    is_paired : bool
        Whether the data is paired (truthy when paired).
    """
    # Group order: the categories for categorical data, otherwise order of appearance.
    if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype):
        swarm_bars_order = pd.unique(plot_data[xvar]).categories
    else:
        swarm_bars_order = pd.unique(plot_data[xvar])

    swarm_means = plot_data.groupby(xvar)[yvar].mean().reindex(index=swarm_bars_order)

    # Work on a copy so the caller's kwargs keep their 'color' entry.
    bar_kwargs = dict(swarm_bars_kwargs)
    user_color = bar_kwargs.pop('color', None)

    # One colour per group. The old code took max() of the group labels here,
    # which raises TypeError for string labels.
    n_groups = len(swarm_bars_order)
    if user_color is not None:
        swarm_bars_colors = [user_color] * n_groups
    elif color_col is not None or is_paired:
        swarm_bars_colors = ['black'] * n_groups
    else:
        swarm_bars_colors = swarm_colors

    for swarm_bars_x, (swarm_bars_y, bar_color) in enumerate(zip(swarm_means, swarm_bars_colors)):
        ax.add_patch(mpatches.Rectangle((swarm_bars_x - 0.25, 0),
                                        0.5, swarm_bars_y, zorder=-1, color=bar_color, **bar_kwargs))

def delta_text_plotter(results: object, ax_to_plot: object, swarm_plot_ax: object, ticks_to_plot: list, delta_text_kwargs: dict, color_col: str,
                       swarm_colors: list, is_paired: bool, proportional: bool, float_contrast: bool,
                       show_mini_meta: bool, mini_meta_delta: object, show_delta2: bool, delta_delta: object):
    """
    Add text to the contrast plot.

    Each contrast's effect size is written as a signed number with two
    decimals. When ``show_mini_meta`` or ``show_delta2`` is set, one extra
    black label is written for the weighted delta / delta-delta.

    Parameters
    ----------
    results : object (Dataframe)
        Dataframe of contrast object comparisons; its ``difference`` column is read.
    ax_to_plot : object
        Matplotlib axis object to plot on.
    swarm_plot_ax : object
        Matplotlib axis object of the swarm plot. Only used to place the extra
        label when ``x_coordinates`` is not supplied.
    ticks_to_plot : list
        x-tick position of each contrast, in the same order as ``results``.
        Not modified.
    delta_text_kwargs : dict
        Keyword arguments for the delta text. The following keys are consumed
        here:

        * ``x_location`` ('right' or 'left', required)
        * ``color``
        * ``x_coordinates``
        * ``y_coordinates``

        Every other key is forwarded to ``ax_to_plot.text``. The caller's dict
        is not modified.
    color_col : str
        Column name of the color column.
    swarm_colors : list
        List of colors used in the plot, indexed by tick position. Not modified.
    is_paired : bool
        Whether the data is paired (truthy when paired).
    proportional : bool
        Whether the data is proportional.
    float_contrast : bool
        Whether the DABEST plot uses Gardner-Altman or Cummings.
    show_mini_meta : bool
        Whether to show the mini meta-analysis.
    mini_meta_delta : object
        Mini meta-analysis object (its ``difference`` is read).
    show_delta2 : bool
        Whether to show the delta-delta.
    delta_delta : object
        delta-delta object (its ``difference`` is read).

    Raises
    ------
    ValueError
        If ``x_location`` is not 'right'/'left', or if user-supplied
        coordinates do not have one entry per label.
    TypeError
        If user-supplied coordinates are not lists.
    """
    # Work on a copy: the option keys are consumed below and must not be
    # removed from the caller's dict.
    text_kwargs = dict(delta_text_kwargs)

    # Begin checks
    delta_text_x_location = text_kwargs.pop('x_location', None)
    if delta_text_x_location not in ('right', 'left'):
        raise ValueError("delta_text_kwargs['x_location'] must be either 'right' or 'left'.")
    if float_contrast:
        delta_text_x_location = 'left'
        text_kwargs["va"] = 'bottom' if results.difference[0] >= 0 else 'top'

    show_extra = show_mini_meta or show_delta2
    total_ticks = len(ticks_to_plot) + (1 if show_extra else 0)

    # One colour per label. The extra mini-meta / delta-delta label is always black.
    user_color = text_kwargs.pop('color', None)
    if user_color is not None:
        text_colors = [user_color] * len(ticks_to_plot)
    elif color_col is not None or is_paired:
        # Same truthiness as the former
        # `color_col is not None or (proportional and is_paired) or is_paired`.
        text_colors = ['black'] * len(ticks_to_plot)
    else:
        text_colors = [swarm_colors[tick] for tick in ticks_to_plot]
    if show_extra:
        text_colors.append('black')

    # Collect the values to print (also the default y-coordinates).
    delta_values = [results.difference[j] for j in range(len(ticks_to_plot))]
    if show_delta2:
        delta_values.append(delta_delta.difference)
    if show_mini_meta:
        delta_values.append(mini_meta_delta.difference)

    # Collect the X-coordinates for the delta text
    delta_text_x_coordinates = text_kwargs.pop('x_coordinates', None)
    if delta_text_x_coordinates is not None:
        if not isinstance(delta_text_x_coordinates, list):
            raise TypeError("delta_text_kwargs['x_coordinates'] must be a list of x-coordinates.")
        if len(delta_text_x_coordinates) != total_ticks:
            raise ValueError("delta_text_kwargs['x_coordinates'] must have the same length as the number of ticks to plot.")
    else:
        x_adjust = 0.48 if delta_text_x_location == 'right' else -0.38
        delta_text_x_coordinates = [tick + x_adjust for tick in ticks_to_plot]
        if show_mini_meta:
            delta_text_x_coordinates.append(max(swarm_plot_ax.get_xticks()) + 2 + x_adjust)
        if show_delta2:
            delta_text_x_coordinates.append(max(swarm_plot_ax.get_xticks()) + 2 - 0.35)

    # Collect the Y-coordinates for the delta text
    delta_text_y_coordinates = text_kwargs.pop('y_coordinates', None)
    if delta_text_y_coordinates is not None:
        if not isinstance(delta_text_y_coordinates, list):
            raise TypeError("delta_text_kwargs['y_coordinates'] must be a list of y-coordinates.")
        if len(delta_text_y_coordinates) != total_ticks:
            raise ValueError("delta_text_kwargs['y_coordinates'] must have the same length as the number of ticks to plot.")
    else:
        delta_text_y_coordinates = delta_values

    # Plot the delta text
    for x, y, value, color in zip(delta_text_x_coordinates, delta_text_y_coordinates,
                                  delta_values, text_colors):
        delta_text = np.format_float_positional(value, precision=2, sign=True, trim="k", min_digits=2)
        ax_to_plot.text(x, y, delta_text, color=color, zorder=5, **text_kwargs)


def DeltaDotsPlotter(plot_data, contrast_axes, delta_id_col, idx, xvar, yvar, is_paired, color_col, float_contrast, plot_palette_raw, delta_dot_kwargs):
    """
    Plot the paired per-observation differences ("delta dots") on the contrast axes.

    For each tuple in ``idx``, every group after the first is compared with a
    reference group:

    * the tuple's first group when ``is_paired == "baseline"``;
    * the immediately preceding group when ``is_paired == "sequential"``.

    Rows of the two groups are paired by position after ``reset_index``. The
    ``yvar`` differences are collected and drawn with ``swarmplot``.

    Parameters
    ----------
    plot_data : object (Dataframe)
        Dataframe of the plot data.
    contrast_axes : object
        Matplotlib axis object to plot on.
    delta_id_col : str
        Column name of the delta id column.
    idx : list
        List of tuples of group names; see above for how they are paired.
    xvar : str
        Column name of the x variable (group labels; matched exactly).
    yvar : str
        Column name of the y variable.
    is_paired : str
        "baseline" or "sequential".
    color_col : str
        Column name of the color column.
    float_contrast : bool
        Whether the DABEST plot uses Gardner-Altman or Cummings.
    plot_palette_raw : list
        List of colors used in the plot.
    delta_dot_kwargs : dict
        Keyword arguments for the delta dots.

    Raises
    ------
    ValueError
        If a comparison is needed and ``is_paired`` is neither "baseline"
        nor "sequential".
    """
    # Checks and initializations
    if color_col is not None:
        plot_palette_deltapts = plot_palette_raw
        delta_plot_data = plot_data[[xvar, yvar, delta_id_col, color_col]]
    else:
        plot_palette_deltapts = "k"
        delta_plot_data = plot_data[[xvar, yvar, delta_id_col]]

    # TODO: to make jitter value more accurate and not just a hardcoded eyeball value
    jitter = 0.6 if float_contrast else 1

    # Create dataframe of delta values
    delta_frames = []
    for group_tuple in idx:
        for position, test_group in enumerate(group_tuple):
            if position == 0:
                continue  # the first group of a tuple is never a test group
            if is_paired == "baseline":
                reference_group = group_tuple[0]
            elif is_paired == "sequential":
                reference_group = group_tuple[position - 1]
            else:
                raise ValueError(
                    "is_paired must be either 'baseline' or 'sequential' to plot delta dots, "
                    "got {!r}.".format(is_paired)
                )
            # Exact label match. A `str.contains` match would treat the name as
            # a regex and also pick up other groups whose names contain it.
            temp_df_exp = delta_plot_data[
                delta_plot_data[xvar] == test_group
            ].reset_index(drop=True)
            temp_df_cont = delta_plot_data[
                delta_plot_data[xvar] == reference_group
            ].reset_index(drop=True)
            delta_df = temp_df_exp.copy()
            delta_df[yvar] = temp_df_exp[yvar] - temp_df_cont[yvar]
            delta_frames.append(delta_df)

    # Concatenate once instead of growing a DataFrame inside the loop.
    final_deltas = pd.concat(delta_frames) if delta_frames else pd.DataFrame()

    # Plot the delta dots (swarmplot is defined at module level in this file).
    swarmplot(
        data=final_deltas,
        x=xvar,
        y=yvar,
        ax=contrast_axes,
        order=None,
        hue=color_col,
        palette=plot_palette_deltapts,
        jitter=jitter,
        is_drop_gutter=True,
        gutter_limit=1,
        **delta_dot_kwargs)
    contrast_axes.legend().set_visible(False)


def slopegraph_plotter(dabest_obj, plot_data, xvar, yvar, color_col, plot_palette_raw, slopegraph_kwargs, rawdata_axes, ytick_color, temp_idx):
    """
    Draw one line per observation ID linking its values across each group tuple.

    Parameters
    ----------
    dabest_obj : object
        Object whose ``id_col`` attribute names the column that identifies
        observations.
    plot_data : pd.DataFrame
        Long-format plot data containing ``id_col``, ``xvar``, ``yvar`` (and
        ``color_col`` when given).
    xvar, yvar : str
        Group-label and value columns.
    color_col : str or None
        Column whose first value per ID selects the line colour from
        ``plot_palette_raw`` and is used as the line label.
    plot_palette_raw : mapping
        Colour lookup keyed by ``color_col`` values.
    slopegraph_kwargs : dict
        Keyword arguments forwarded to ``rawdata_axes.plot``. The caller's dict
        is not modified.
    rawdata_axes : matplotlib Axes
        Axes to draw on.
    ytick_color : color
        Line colour used when ``color_col`` is None.
    temp_idx : iterable of tuples
        Group tuples. Each tuple occupies consecutive x positions, and the
        first tuple starts at x = 0.
    """
    # Work on a copy: color/label are set per line below and must not leak
    # into the caller's dict.
    slopegraph_kwargs = dict(slopegraph_kwargs)

    # Pivot the long (melted) data.
    pivot_values = [yvar] if color_col is None else [yvar, color_col]
    pivoted_plot_data = pd.pivot(
        data=plot_data,
        index=dabest_obj.id_col,
        columns=xvar,
        values=pivot_values,
    )

    x_start = 0
    for current_tuple in temp_idx:
        current_pair = pivoted_plot_data.loc[
            :, pd.MultiIndex.from_product([pivot_values, current_tuple])
        ].dropna()
        grp_count = len(current_tuple)
        # Iterate through the data for the current tuple.
        for _, observation in current_pair.iterrows():
            x_points = list(range(x_start, x_start + grp_count))
            y_points = observation[yvar].tolist()

            if color_col is None:
                slopegraph_kwargs["color"] = ytick_color
            else:
                # Positional access: plain `[0]` on a label-indexed Series is
                # deprecated in pandas >= 2.1.
                color_key = observation[color_col].iloc[0]
                if isinstance(color_key, (str, int, float, np.integer, np.floating)):
                    slopegraph_kwargs["color"] = plot_palette_raw[color_key]
                    slopegraph_kwargs["label"] = color_key

            rawdata_axes.plot(x_points, y_points, **slopegraph_kwargs)

        x_start = x_start + grp_count

def plot_minimeta_or_deltadelta_violins(show_mini_meta, effectsize_df, ci_type, rawdata_axes,
                                        contrast_axes, violinplot_kwargs, halfviolin_alpha, ytick_color,
                                        es_marker_size, group_summary_kwargs, contrast_xtick_labels, effect_size
                                        ):
    """
    Draw the extra contrast for a mini-meta weighted delta or a delta-delta.

    The bootstrap distribution is drawn as a grey half-violin at x = (last
    raw-data tick + 2). The point estimate is drawn there as a dot and the
    confidence interval as a vertical line.

    Parameters
    ----------
    show_mini_meta : bool
        True plots ``effectsize_df.mini_meta_delta``; False plots
        ``effectsize_df.delta_delta``.
    effectsize_df : object
        Effect-size object holding the two summaries above.
    ci_type : str
        ``"bca"`` selects the BCa bounds; anything else selects the percentile bounds.
    rawdata_axes : matplotlib Axes
        Its largest x tick sets the position of the new contrast.
    contrast_axes : matplotlib Axes
        Axes to draw on.
    violinplot_kwargs : dict
        Forwarded to ``contrast_axes.violinplot``.
    halfviolin_alpha : float
        Transparency of the half-violin.
    ytick_color : color
        Colour of the dot and CI line.
    es_marker_size : float
        Marker size of the effect-size dot.
    group_summary_kwargs : dict
        Its ``"lw"`` entry sets the CI line width.
    contrast_xtick_labels : list
        Tick labels, extended in place.
    effect_size : str
        When it equals ``"delta_g"`` (and ``show_mini_meta`` is False), the
        label "deltas' g" is used instead of "delta-delta".

    Returns
    -------
    list
        ``contrast_xtick_labels`` with a blank label and the new contrast's
        label appended.
    """
    if show_mini_meta:
        summary = effectsize_df.mini_meta_delta
        bootstraps = summary.bootstraps_weighted_delta
    else:
        summary = effectsize_df.delta_delta
        bootstraps = summary.bootstraps_delta_delta
    point_estimate = summary.difference
    use_bca = ci_type == "bca"
    lower = summary.bca_low if use_bca else summary.pct_low
    upper = summary.bca_high if use_bca else summary.pct_high

    # Infinite bootstrap values (either sign) are dropped before plotting.
    x_pos = max(rawdata_axes.get_xticks()) + 2
    finite_bootstraps = bootstraps[~np.isinf(bootstraps)]
    violin = contrast_axes.violinplot(finite_bootstraps, positions=[x_pos], **violinplot_kwargs)
    halfviolin(violin, fill_color="grey", alpha=halfviolin_alpha)

    # Point estimate, then its confidence interval.
    contrast_axes.plot(
        [x_pos], point_estimate,
        marker="o", color=ytick_color, markersize=es_marker_size,
    )
    contrast_axes.plot(
        [x_pos, x_pos], [lower, upper],
        linestyle="-", color=ytick_color, linewidth=group_summary_kwargs["lw"],
    )

    if show_mini_meta:
        new_label = "Weighted delta"
    elif effect_size == "delta_g":
        new_label = "deltas' g"
    else:
        new_label = "delta-delta"
    contrast_xtick_labels.extend(["", new_label])

    return contrast_xtick_labels


def effect_size_curve_plotter(ticks_to_plot, results, ci_type, contrast_axes, violinplot_kwargs, halfviolin_alpha, 
                              ytick_color, es_marker_size, group_summary_kwargs, contrast_xtick_labels, 
                              bootstraps_color_by_group, plot_palette_contrast):
    for j, tick in enumerate(ticks_to_plot):
        current_group = results.test[j]
        current_control = results.control[j]
        current_bootstrap = results.bootstraps[j]
        current_effsize = results.difference[j]
        if ci_type == "bca":
            current_ci_low = results.bca_low[j]
            current_ci_high = results.bca_high[j]
        else:
            current_ci_low = results.pct_low[j]
            current_ci_high = results.pct_high[j]

        # Create the violinplot.
        # New in v0.2.6: drop negative infinities before plotting.
        v = contrast_axes.violinplot(
            current_bootstrap[~np.isinf(current_bootstrap)],
            positions=[tick],
            **violinplot_kwargs
        )
        # Turn the violinplot into half, and color it the same as the swarmplot.
        # Do this only if the color column is not specified.
        # Ideally, the alpha (transparency) fo the violin plot should be
        # less than one so the effect size and CIs are visible.
        if bootstraps_color_by_group:
            fc = plot_palette_contrast[current_group]
        else:
            fc = "grey"

        halfviolin(v, fill_color=fc, alpha=halfviolin_alpha)

        # Plot the effect size.
        contrast_axes.plot(
            [tick],
            current_effsize,
            marker="o",
            color=ytick_color,
            markersize=es_marker_size,
        )

        # Plot the confidence interval.
        contrast_axes.plot(
            [tick, tick],
            [current_ci_low, current_ci_high],
            linestyle="-",
            color=ytick_color,
            linewidth=group_summary_kwargs["lw"],
        )

        contrast_xtick_labels.append(
            "{}\nminus\n{}".format(current_group, current_control)
        )
    return current_group, current_control, current_effsize


def grid_key_WIP(is_paired, idx, all_plot_groups, gridkey_rows, rawdata_axes, contrast_axes,
                 plot_data, xvar, yvar, results, show_delta2, show_mini_meta, float_contrast, plot_kwargs,):
    """
    Replace the x-axis tick labels with a "grid key" table drawn under the plot.

    Each table row is one entry of ``gridkey_rows``. Columns are chosen as follows:

    * normally, one column per group in ``all_plot_groups``;
    * when pairs are merged, one column per second group of each ``idx`` tuple.

    A cell holds a dot (U+25CF) when the row label occurs as a substring of the
    column's group name. Optional extra rows show each group's observation
    count ("Ns") and its effect size (a Delta row).

    Parameters
    ----------
    is_paired : str or None
        Paired-design setting; ``None`` means unpaired.
    idx : iterable of tuples
        Group tuples of the plot.
    all_plot_groups : list
        All plotted group names, in x order.
    gridkey_rows : list
        Row labels. The caller's list is not modified.
    rawdata_axes, contrast_axes : matplotlib Axes
        Raw-data and contrast axes. Both x axes are hidden afterwards.
    plot_data : pd.DataFrame
        Long-format data, used to count observations per group.
    xvar, yvar : str
        Group-label and value columns of ``plot_data``.
    results : pd.DataFrame
        Contrast results; its ``test`` and ``difference`` columns are read.
    show_delta2, show_mini_meta : bool
        When either is set, the table width becomes
        len(groups) / (len(groups) + 2) of the axes width.
    float_contrast : bool
        Gardner-Altman layout: draw on ``rawdata_axes`` instead of
        ``contrast_axes``.
    plot_kwargs : dict
        Must contain ``gridkey_show_Ns``, ``gridkey_show_es`` and
        ``gridkey_merge_pairs``.

    Raises
    ------
    TypeError
        If ``gridkey_rows`` is not a list.
    """
    gridkey_show_Ns = plot_kwargs["gridkey_show_Ns"]
    gridkey_show_es = plot_kwargs["gridkey_show_es"]
    gridkey_merge_pairs = plot_kwargs["gridkey_merge_pairs"]

    # Merging pairs only applies to paired data whose idx tuples are pairs.
    if gridkey_merge_pairs and is_paired is not None:
        if any(len(i) > 2 for i in idx):
            warnings.warn(
                "gridkey_merge_pairs=True only works if all idx in tuples have only two items. gridkey_merge_pairs has automatically been set to False"
            )
            gridkey_merge_pairs = False
    elif gridkey_merge_pairs and is_paired is None:
        warnings.warn(
            "gridkey_merge_pairs=True is only applicable for paired data."
        )
        gridkey_merge_pairs = False

    # When merging, one column per pair (labelled by the pair's second group).
    if gridkey_merge_pairs and is_paired is not None:
        groups_for_gridkey = [i[1] for i in idx]
    else:
        groups_for_gridkey = all_plot_groups

    # raise errors if gridkey_rows is not a list, or if the list is empty
    if not isinstance(gridkey_rows, list):
        raise TypeError("gridkey_rows must be a list.")
    if len(gridkey_rows) == 0:
        warnings.warn("gridkey_rows is an empty list.")

    # Copy so the "Ns"/Delta rows appended below never leak into the caller's
    # list. Otherwise they would accumulate across repeated plot calls.
    row_labels = list(gridkey_rows)

    # Warn about row labels that do not occur in any group name.
    for row in row_labels:
        if not any(row in group for group in groups_for_gridkey):
            if is_paired is not None:
                warnings.warn(
                    f"{row} is not in any idx. Please check. Alternatively, merging gridkey pairs may not be suitable for your data; try passing gridkey_merge_pairs=False."
                )
            else:
                warnings.warn(f"{row} is not in any idx. Please check.")

    # Populate table: a dot where the group name contains the row label.
    table_cellcols = [
        ["\u25CF" if str(row) in group else "" for group in groups_for_gridkey]
        for row in row_labels
    ]

    # Adds a row for Ns with the Ns values (counts computed once, not per group).
    if gridkey_show_Ns:
        row_labels.append("Ns")
        group_counts = plot_data.groupby(xvar).count()[yvar]
        table_cellcols.append([str(group_counts.loc[group]) for group in groups_for_gridkey])

    # Adds a row for effect sizes (signed, 2 decimals); "-" when not a test group.
    if gridkey_show_es:
        row_labels.append("\u0394")
        results_list = results.test.to_list()
        effsize_list = []
        for group in groups_for_gridkey:
            if group in results_list:
                curr_esval = results.loc[results["test"] == group]["difference"].iloc[0]
                effsize_list.append(np.format_float_positional(
                    curr_esval,
                    precision=2,
                    sign=True,
                    trim="k",
                    min_digits=2,
                ))
            else:
                effsize_list.append("-")
        table_cellcols.append(effsize_list)

    # If Gardner-Altman plot, plot on raw data and not contrast axes
    axes_ploton = rawdata_axes if float_contrast else contrast_axes

    # Account for extended x axis in case of show_delta2 or show_mini_meta
    x_groups_for_width = len(groups_for_gridkey)
    if show_delta2 or show_mini_meta:
        x_groups_for_width += 2
    gridkey_width = len(groups_for_gridkey) / x_groups_for_width

    gridkey = axes_ploton.table(
        cellText=table_cellcols,
        rowLabels=row_labels,
        cellLoc="center",
        bbox=[
            0,
            -len(row_labels) * 0.1 - 0.05,
            gridkey_width,
            len(row_labels) * 0.1,
        ],
        **{"alpha": 0.5}
    )

    # Row-label cells (column -1): no borders, right-aligned text.
    for cell in gridkey._cells:
        if cell[1] == -1:
            gridkey._cells[cell].visible_edges = "open"
            gridkey._cells[cell].set_text_props(**{"ha": "right"})

    # turns off both x axes
    rawdata_axes.get_xaxis().set_visible(False)
    contrast_axes.get_xaxis().set_visible(False)

def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color, plot_palette_bar, plot_kwargs, barplot_kwargs):
    """
    Draw the raw data as two layered seaborn bar plots on ``rawdata_axes``.

    Both layers use the ``all_plot_groups`` order:

    1. An outline layer of height 1 for every group, with transparent fill and
       edges in ``bar_color``. It is presumably the "100%" frame for
       proportional (0/1) data, given its "proportion" column; confirm against
       callers.
    2. A filled layer of ``yvar`` per group from ``plot_data``, coloured by
       ``plot_palette_bar``. Seaborn aggregates the values; the mean is its
       default unless ``barplot_kwargs`` overrides it.

    Bar patches are then recentred and resized to ``plot_kwargs["bar_width"]``.

    Parameters
    ----------
    xvar : str
        Column of ``plot_data`` holding the group labels.
    yvar : str
        Column of ``plot_data`` holding the values of the filled layer.
    all_plot_groups : list
        Group labels, in plotting order.
    rawdata_axes : matplotlib Axes
        Axes to draw on.
    plot_data : pd.DataFrame
        Long-format plot data.
    bar_color : color
        Edge colour of the outline layer.
    plot_palette_bar : palette
        Palette passed to the filled layer.
    plot_kwargs : dict
        Plot options; only ``"bar_width"`` is read here.
    barplot_kwargs : dict
        Extra keyword arguments forwarded to the filled layer's ``sns.barplot``.
    """
    # Plot the raw data as a barplot.
    # Outline layer: one bar of height 1 per group.
    bar1_df = pd.DataFrame(
        {xvar: all_plot_groups, "proportion": np.ones(len(all_plot_groups))}
    )
    bar1 = sns.barplot(
        data=bar1_df,
        x=xvar,
        y="proportion",
        ax=rawdata_axes,
        order=all_plot_groups,
        linewidth=2,
        facecolor=(1, 1, 1, 0),
        edgecolor=bar_color,
        zorder=1,
    )
    # Filled layer: the actual yvar values per group.
    bar2 = sns.barplot(
        data=plot_data,
        x=xvar,
        y=yvar,
        ax=rawdata_axes,
        order=all_plot_groups,
        palette=plot_palette_bar,
        zorder=1,
        **barplot_kwargs
    )
    # adjust the width of bars
    # NOTE(review): sns.barplot returns the Axes it drew on, so bar1 is
    # presumably rawdata_axes itself. bar1.patches would then include the
    # filled layer's patches too (and any drawn earlier); confirm.
    bar_width = plot_kwargs["bar_width"]
    for bar in bar1.patches:
        # Keep each bar's centre while changing its width.
        x = bar.get_x()
        width = bar.get_width()
        centre = x + width / 2.0
        bar.set_x(centre - bar_width / 2.0)
        bar.set_width(bar_width)
    # NOTE(review): bare Ellipsis statement (a no-op). Confirm nothing was
    # meant to follow here, e.g. further use of `bar2`, which is otherwise unused.
    ...

In [None]:
# | export
def swarmplot(
    data: pd.DataFrame,
    x: str,
    y: str,
    ax: axes.Subplot,
    order: List = None,
    hue: str = None,
    palette: Union[Iterable, str] = "black",
    zorder: float = 1,
    size: float = 5,
    side: str = "center",
    jitter: float = 1,
    is_drop_gutter: bool = True,
    gutter_limit: float = 0.5,
    **kwargs,
):
    """
    Draw a swarm plot of ``y`` grouped by ``x`` on ``ax``.

    Thin functional wrapper: builds a ``SwarmPlot`` and calls its ``plot``.

    Parameters
    ----------
    data : pd.DataFrame
        The input data as a pandas DataFrame.
    x : str
        Column of ``data`` used for the x-axis categories.
    y : str
        Column of ``data`` used for the y-axis values.
    ax : axes._subplots.Subplot | axes._axes.Axes
        Matplotlib axes the plot is drawn on.
    order : List
        Order in which the x-axis categories are displayed. Default is None.
    hue : str
        Column of ``data`` that determines colour grouping. If None (the
        default), the grouping is by ``x``.
    palette : Union[Iterable, str]
        Colour palette for the points. Default is "black".
    zorder : int | float
        Drawing order relative to other matplotlib artists. Default is 1.
    size : int | float
        Marker size of the points. Default is 5.
    side : str
        Side on which points are swarmed: "center", "left" or "right".
        Default is "center".
    jitter : int | float
        Controls the distance between points. Default is 1.
    is_drop_gutter : bool
        If True, points that hit the gutters are dropped; otherwise they are
        readjusted. Default is True.
    gutter_limit : int | float
        Limit for points hitting the gutters. Default is 0.5.
    **kwargs:
        Additional keyword arguments forwarded to ``SwarmPlot.plot``.

    Returns
    -------
    axes._subplots.Subplot | axes._axes.Axes
        The axes the swarm plot was drawn on (as returned by ``SwarmPlot.plot``).
    """
    plotter = SwarmPlot(
        data=data,
        x=x,
        y=y,
        ax=ax,
        order=order,
        hue=hue,
        palette=palette,
        zorder=zorder,
        size=size,
        side=side,
        jitter=jitter,
    )
    return plotter.plot(is_drop_gutter, gutter_limit, ax, **kwargs)


class SwarmPlot:
    def __init__(
        self,
        data: pd.DataFrame,
        x: str,
        y: str,
        ax: axes.Subplot,
        order: List = None,
        hue: str = None,
        palette: Union[Iterable, str] = "black",
        zorder: float = 1,
        size: float = 5,
        side: str = "center",
        jitter: float = 1,
    ):
        """
        Initialize a SwarmPlot instance.

        Validates the inputs, fills in a default `order`, normalises `palette`
        into a dict, keeps a copy of `data` sorted by the x categories (following
        `order`) and then by y, sets the x-limits of `ax` (and the y-limits when
        y-autoscaling is on) and derives the swarm spacing from the axes geometry.

        Parameters
        ----------
        data : pd.DataFrame
            The input data as a pandas DataFrame.
        x : str
            The column in the DataFrame to be used as the x-axis.
        y : str
            The column in the DataFrame to be used as the y-axis.
        ax : axes.Subplot
            Matplotlib Axes object on which the plot will be drawn.
        order : List
            The order in which x-axis categories should be displayed. Default is None,
            in which case it is derived from the x column.
        hue : str
            The column in the DataFrame that determines the grouping for color.
            If None (by default), it assumes that it is being grouped by x.
        palette : Union[Iterable, str]
            The color palette to be used for plotting. Default is "black".
        zorder : int | float
            The z-order for drawing the swarm plot wrt other matplotlib drawings. Default is 1.
        size : int | float
            The size of the markers in the swarm plot. Default is 5. It is multiplied
            by 4 before being passed to `Axes.scatter` as `s`.
        side : str
            The side on which points are swarmed ("center", "left", or "right"). Default is "center".
        jitter : int | float
            Determines the distance between points. Default is 1.

        Returns
        -------
        None

        Raises
        ------
        ValueError, IndexError
            Raised by the input validation when an argument is invalid.
        """
        self.__x = x
        self.__y = y
        self.__order = order
        self.__hue = hue
        self.__zorder = zorder
        self.__palette = palette
        self.__jitter = jitter

        # Input validation
        self._check_errors(data, ax, size, side)

        self.__size = size * 4
        self.__side = side.lower()
        self.__data = data
        self.__color_col = self.__x if self.__hue is None else self.__hue

        # Generate default values
        if order is None:
            self.__order = self._generate_order()

        # Reformatting
        if not isinstance(self.__palette, dict):
            self.__palette = self._format_palette(self.__palette)

        # Always (re)cast x to an ordered categorical built from `order` so that
        # sorting and the later groupby follow `order`, even when x is already
        # categorical with different/extra categories. When `order` was derived
        # from an existing categorical this recast keeps the same categories.
        data_copy = data.copy(deep=True)
        data_copy[self.__x] = data_copy[self.__x].astype(
            CategoricalDtype(categories=self.__order, ordered=True)
        )
        data_copy.sort_values(by=[self.__x, self.__y], inplace=True)
        self.__data_copy = data_copy

        x_vals = range(len(self.__order))
        y_vals = self.__data_copy[self.__y]

        x_min = min(x_vals)
        x_max = max(x_vals)
        ax.set_xlim(left=x_min - 0.5, right=x_max + 0.5)

        # Series.min/max skip NaN; the builtins' result depends on where NaNs sit.
        y_data_min = y_vals.min()
        y_data_max = y_vals.max()
        y_range = y_data_max - y_data_min
        y_min = y_data_min - 0.05 * y_range
        y_max = y_data_max + 0.05 * y_range

        # ylim is set manually to override Axes.autoscale if it hasn't already been scaled at least once
        if ax.get_autoscaley_on():
            ax.set_ylim(bottom=y_min, top=y_max)

        figw, figh = ax.get_figure().get_size_inches()
        w = (ax.get_position().xmax - ax.get_position().xmin) * figw
        h = (ax.get_position().ymax - ax.get_position().ymin) * figh
        ax_xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
        ax_yspan = ax.get_ylim()[1] - ax.get_ylim()[0]

        # increases jitter distance based on number of swarms that is going to be drawn
        jitter = jitter * (1 + 0.05 * (math.log(ax_xspan)))

        # Marker diameter expressed in x-axis (gsize) and y-axis (dsize) data units,
        # from the marker size, the axes span and the axes' physical size in inches.
        gsize = (
            math.sqrt(self.__size) * 1.0 / (70 / jitter) * ax_xspan * 1.0 / (w * 0.8)
        )
        dsize = (
            math.sqrt(self.__size) * 1.0 / (70 / jitter) * ax_yspan * 1.0 / (h * 0.8)
        )
        self.__gsize = gsize
        self.__dsize = dsize

    def _check_errors(
        self, data: pd.DataFrame, ax: axes.Subplot, size: float, side: str
    ) -> None:
        """
        Check the validity of input parameters. Raises exceptions if detected.

        Parameters
        ----------
        data : pd.Dataframe
            Input data used for generation of the swarmplot.
        ax : axes.Subplot
            Matplotlib AxesSubplot object for which the plot would be drawn on.
        size : int | float
            scalar value determining size of dots of the swarmplot.
        side: str
            The side on which points are swarmed ("center", "left", or "right"). Default is "center".

        Returns
        -------
        None
        """
        # Type enforcement
        if not isinstance(data, pd.DataFrame):
            raise ValueError("`data` must be a Pandas Dataframe.")
        if not isinstance(ax, (axes._subplots.Subplot, axes._axes.Axes)):
            raise ValueError(
                f"`ax` must be a Matplotlib AxesSubplot. The current `ax` is a {type(ax)}"
            )
        if not isinstance(size, (int, float)):
            raise ValueError("`size` must be a scalar or float.")
        if not isinstance(side, str):
            raise ValueError(
                "Invalid `side`. Must be one of 'center', 'right', or 'left'."
            )
        if not isinstance(self.__x, str):
            raise ValueError("`x` must be a string.")
        if not isinstance(self.__y, str):
            raise ValueError("`y` must be a string.")
        if not isinstance(self.__zorder, (int, float)):
            raise ValueError("`zorder` must be a scalar or float.")
        if not isinstance(self.__jitter, (int, float)):
            raise ValueError("`jitter` must be a scalar or float.")
        if not isinstance(self.__palette, (str, Iterable)):
            raise ValueError("`palette` must be either a string indicating a color name or an Iterable.")
        if self.__hue is not None and not isinstance(self.__hue, str):
            raise ValueError("`hue` must be either a string or None.")
        if self.__order is not None and not isinstance(self.__order, Iterable):
            raise ValueError("`order` must be either an Iterable or None.")

        # More thorough input validation checks
        if self.__x not in data.columns:
            err = "{0} is not a column in `data`.".format(self.__x)
            raise IndexError(err)
        if self.__y not in data.columns:
            err = "{0} is not a column in `data`.".format(self.__y)
            raise IndexError(err)
        if self.__hue is not None and self.__hue not in data.columns:
            err = "{0} is not a column in `data`.".format(self.__hue)
            raise IndexError(err)

        color_col = self.__x if self.__hue is None else self.__hue
        if self.__order is not None:
            for group_i in self.__order:
                if group_i not in pd.unique(data[self.__x]):
                    err = "{0} in `order` is not in the '{1}' column of `data`.".format(
                        group_i, self.__x
                    )
                    raise IndexError(err)

        if isinstance(self.__palette, str) and self.__palette.strip() == "":
            err = "`palette` cannot be an empty string. It must be either a string indicating a color name or an Iterable."
            raise ValueError(err)
        if isinstance(self.__palette, dict):
            # TODO: to add detection of when dict length is less than size of unique_items
            for group_i, color_i in self.__palette.items():
                if group_i not in pd.unique(data[color_col]):
                    err = (
                        "{0} in `palette` is not in the '{1}' column of `data`.".format(
                            group_i, color_col
                        )
                    )
                    raise IndexError(err)
                if isinstance(color_i, str) and color_i.strip() == "":
                    err = "The color mapping for {0} in `palette` is an empty string. It must contain a color name.".format(group_i)
                    raise ValueError(err) 

        if side.lower() not in ["center", "right", "left"]:
            raise ValueError(
                "Invalid `side`. Must be one of 'center', 'right', or 'left'."
            )

        return None

    def _generate_order(self) -> List:
        """
        Generates order value that determines the order in which x-axis categories should be displayed.

        Parameters
        ----------
        None

        Returns
        -------
        List:
            contains the order in which the x-axis categories should be displayed.
        """
        if isinstance(self.__data[self.__x].dtype, pd.CategoricalDtype):
            order = pd.unique(self.__data[self.__x]).categories.tolist()
        else:
            order = pd.unique(self.__data[self.__x]).tolist()

        return order

    def _format_palette(self, palette: Union[str, List, Tuple]) -> Dict:
        """
        Reformats palette into appropriate Dictionary form for swarm plot

        Parameters
        ----------
        palette: str | List | Tuple
            The color palette used for the swarm plot. Conventions are based on Matplotlib color
            specifications.

            Could be a singular string value - in which case, would be a singular color name.
            In the case of a List or Tuple - it could be a Sequence of color names or RGB(A) values.

        Returns
        -------
        Dict:
            Dictionary mapping unique groupings in the color column (of the data used for the swarm plot)
            to a color name (str) or a RGB(A) value (Tuple[float, float, float] | List[float, float, float]).
        """
        reformatted_palette = dict()
        groups = pd.unique(self.__data[self.__color_col]).tolist()

        if isinstance(palette, str):
            for group_i in groups:
                reformatted_palette[group_i] = palette
        if isinstance(palette, (list, tuple)):
            if len(groups) != len(palette):
                err = (
                    "unique values in '{0}' column in `data` "
                    "and `palette` do not have the same length. Number of unique values is {1} "
                    "while length of palette is {2}. The assignment of the colors in the "
                    "palette will be cycled."
                ).format(self.__color_col, len(groups), len(palette))
                warnings.warn(err)
            for i, group_i in enumerate(groups):
                reformatted_palette[group_i] = palette[i % len(palette)]

        return reformatted_palette

    def _swarm(
        self, values: Iterable[float], gsize: float, dsize: float, side: str
    ) -> pd.Series:
        """
        Perform the swarm algorithm to position points without overlap.

        Parameters
        ----------
        values : Iterable[int | float]
            The values to be plotted.
        gsize : int | float
            The size of the gap between points.
        dsize : int | float
            The size of the markers.
        side : str
            The side on which points are swarmed ("center", "left", or "right").

        Returns
        -------
        pd.Series:
            The x-offset values for the swarm plot.
        """
        # Input validation
        if not isinstance(values, Iterable):
            raise ValueError("`values` must be an Iterable")
        if not isinstance(gsize, (int, float)):
            raise ValueError("`gsize` must be a scalar or float.")
        if not isinstance(dsize, (int, float)):
            raise ValueError("`dsize` must be a scalar or float.")

        # Sorting algorithm based off of: https://github.com/mgymrek/pybeeswarm
        points_data = pd.DataFrame(
            {"y": [yval * 1.0 / dsize for yval in values], "x": [0] * len(values)}
        )
        for i in range(1, points_data.shape[0]):
            y_i = points_data["y"].values[i]
            points_placed = points_data[0:i]
            is_points_overlap = (
                abs(y_i - points_placed["y"]) < 1
            )  # Checks if y_i is overlapping with any points already placed
            if any(is_points_overlap):
                points_placed = points_placed[is_points_overlap]
                x_offsets = points_placed["y"].apply(
                    lambda y_j: math.sqrt(1 - (y_i - y_j) ** 2)
                )
                if side == "center":
                    potential_x_offsets = pd.Series(
                        [0]
                        + (points_placed["x"] + x_offsets).tolist()
                        + (points_placed["x"] - x_offsets).tolist()
                    )
                if side == "right":
                    potential_x_offsets = pd.Series(
                        [0] + (points_placed["x"] + x_offsets).tolist()
                    )
                if side == "left":
                    potential_x_offsets = pd.Series(
                        [0] + (points_placed["x"] - x_offsets).tolist()
                    )
                bad_x_offsets = []
                for x_i in potential_x_offsets:
                    dists = (y_i - points_placed["y"]) ** 2 + (
                        x_i - points_placed["x"]
                    ) ** 2
                    if any([item < 0.999 for item in dists]):
                        bad_x_offsets.append(True)
                    else:
                        bad_x_offsets.append(False)
                potential_x_offsets[bad_x_offsets] = np.infty
                abs_potential_x_offsets = [abs(_) for _ in potential_x_offsets]
                valid_x_offset = potential_x_offsets[
                    abs_potential_x_offsets.index(min(abs_potential_x_offsets))
                ]
                points_data.loc[i, "x"] = valid_x_offset
            else:
                points_data.loc[i, "x"] = 0

        points_data.loc[np.isnan(points_data["y"]), "x"] = np.nan

        return points_data["x"] * gsize

    def _adjust_gutter_points(
        self,
        points_data: pd.DataFrame,
        x_position: float,
        is_drop_gutter: bool,
        gutter_limit: float,
        value_column: str,
    ) -> pd.DataFrame:
        """
        Adjust points that hit the gutters or drop them based on the provided conditions.

        Parameters
        ----------
        points_data: pd.DataFrame
            Data containing coordinates of points for the swarm plot.
        x_position: int | float
            X-coordinate of the center of a singular swarm group of the swarm plot
        is_drop_gutter : bool
            If True, drop points that hit the gutters; otherwise, readjust them.
        gutter_limit : int | float
            The limit for points hitting the gutters.
        value_column : str
            column in points_data that contains the coordinates for the points in the axis against the gutter

        Returns
        -------
        pd.DataFrame:
            DataFrame with adjusted points based on the gutter limit.
        """
        if self.__side == "center":
            gutter_limit = gutter_limit / 2

        hit_gutter = abs(points_data[value_column] - x_position) >= gutter_limit
        total_num_of_points = points_data.shape[0]
        num_of_points_hit_gutter = points_data[hit_gutter].shape[0]
        if any(hit_gutter):
            if is_drop_gutter:
                # Drop points that hit gutter
                points_data.drop(points_data[hit_gutter].index.to_list(), inplace=True)
                err = (
                    "{0:.1%} of the points cannot be placed. "
                    "You might want to decrease the size of the markers."
                ).format(num_of_points_hit_gutter / total_num_of_points)
                warnings.warn(err)
            else:
                for i in points_data[hit_gutter].index:
                    points_data.loc[i, value_column] = np.sign(
                        points_data.loc[i, value_column]
                    ) * (x_position + gutter_limit)

        return points_data

    def plot(
        self, is_drop_gutter: bool, gutter_limit: float, ax: axes.Subplot, **kwargs
    ) -> axes.Subplot:
        """
        Generate a swarm plot.

        Draws one scatter per x group (in the order of the sorted data copy), with
        x positions 0, 1, 2, ... and tick labels set to the group names.

        Parameters
        ----------
        is_drop_gutter : bool
            If True, drop points that hit the gutters; otherwise, readjust them.
        gutter_limit : int | float
            The limit for points hitting the gutters.
        ax : axes.Subplot
            The matplotlib Axes to which the swarm plot will be added.
        **kwargs:
            Additional keyword arguments to be passed to the scatter plot.

        Returns
        -------
        axes.Subplot:
            The matplotlib Axes containing the swarm plot.

        Raises
        ------
        ValueError
            If `is_drop_gutter` is not a bool or `gutter_limit` is not a number.
        """
        # Input validation
        if not isinstance(is_drop_gutter, bool):
            raise ValueError("`is_drop_gutter` must be a boolean.")
        if not isinstance(gutter_limit, (int, float)):
            raise ValueError("`gutter_limit` must be a scalar or float.")

        # Assumptions are that self.__data_copy is already sorted according to self.__order
        x_position = 0  # x-coordinate of center of each individual swarm of the swarm plot
        x_tick_labels = []
        # The x column of the data copy is categorical; observed=False keeps one group
        # per category (empty ones included) so positions stay aligned with `order`,
        # and passing it explicitly avoids depending on a changing pandas default.
        for group_i, values_i in self.__data_copy.groupby(self.__x, observed=False):
            values_i_y = values_i[self.__y]
            if len(values_i_y) > 0:
                x_offset = np.asarray(
                    self._swarm(
                        values=values_i_y,
                        gsize=self.__gsize,
                        dsize=self.__dsize,
                        side=self.__side,
                    ),
                    dtype=float,
                )
            else:
                x_offset = np.zeros(0, dtype=float)
            # assign() builds a new frame instead of writing into the groupby slice;
            # the offsets are shifted to this swarm's centre.
            values_i = values_i.assign(x_new=x_position + x_offset)
            values_i = self._adjust_gutter_points(
                values_i, x_position, is_drop_gutter, gutter_limit, "x_new"
            )
            x_tick_labels.append(group_i)
            x_position = x_position + 1

            if values_i.empty:
                # An (empty) scatter is still drawn for this group.
                ax.scatter(
                    values_i["x_new"],
                    values_i[self.__y],
                    s=self.__size,
                    zorder=self.__zorder,
                    **kwargs,
                )
                continue

            if self.__hue is not None:
                # color swarms based on `hue` column
                cmap_values, index = np.unique(
                    values_i[self.__hue], return_inverse=True
                )
                cmap = ListedColormap(
                    [self.__palette[cmap_group_i] for cmap_group_i in cmap_values]
                )
                ax.scatter(
                    values_i["x_new"],
                    values_i[self.__y],
                    s=self.__size,
                    c=index,
                    cmap=cmap,
                    zorder=self.__zorder,
                    edgecolor="face",
                    **kwargs,
                )
            else:
                # color swarms based on `x` column
                ax.scatter(
                    values_i["x_new"],
                    values_i[self.__y],
                    s=self.__size,
                    c=self.__palette[group_i],
                    zorder=self.__zorder,
                    edgecolor="face",
                    **kwargs,
                )

        ax.get_xaxis().set_ticks(np.arange(x_position))
        ax.get_xaxis().set_ticklabels(x_tick_labels)

        return ax