# Forest plot

> Creating forest plots from contrast objects.

- order: 4

In [None]:
#| default_exp forest_plot

In [None]:
#| hide
from __future__ import annotations

In [None]:
#| hide
# nbdev helpers for documenting this notebook.
from nbdev.showdoc import *
import nbdev
# Run nbdev's export step so the `#| export` cells are written out as library
# code (presumably into the module named by `#| default_exp` — confirm).
nbdev.nbdev_export()

In [None]:
#| hide
import dabest

In [None]:
#| export
import matplotlib.pyplot as plt
# %matplotlib inline
import seaborn as sns
from typing import List, Optional, Union


In [None]:
#| export
def load_plot_data(
    contrasts: List, effect_size: str = "mean_diff", contrast_type: str = "delta2"
) -> List:
    """
    Fetch the per-contrast result objects that a forest plot draws.

    For every contrast object, the attribute named by ``effect_size`` is read,
    and then the attribute selected by ``contrast_type`` is read from it
    (``delta_delta`` for ``'delta2'``, ``mini_meta_delta`` for ``'mini_meta'``).

    Parameters
    ----------
    contrasts : List
        List of contrast objects.
    effect_size : str
        One of 'mean_diff', 'median_diff', 'cliffs_delta', 'cohens_d',
        'hedges_g' or 'delta_g'.
    contrast_type : str
        Either 'delta2' or 'mini_meta'.

    Returns
    -------
    List
        One result object per contrast, in input order.

    Raises
    ------
    ValueError
        If ``effect_size`` or ``contrast_type`` is unsupported (the effect size
        is reported first when both are invalid).
    """
    supported_effects = frozenset(
        ("mean_diff", "median_diff", "cliffs_delta", "cohens_d", "hedges_g", "delta_g")
    )
    nested_attr_by_contrast = {"delta2": "delta_delta", "mini_meta": "mini_meta_delta"}

    # Both lookups happen before either check, matching the original evaluation order.
    effect_is_known = effect_size in supported_effects
    nested_attr = nested_attr_by_contrast.get(contrast_type)

    if not effect_is_known:
        raise ValueError(f"Invalid effect_size: {effect_size}")
    if nested_attr is None:
        raise ValueError(f"Invalid contrast_type: {contrast_type}. Available options: [`delta2`, `mini_meta`]")

    plot_data = []
    for contrast in contrasts:
        effect_obj = getattr(contrast, effect_size)
        plot_data.append(getattr(effect_obj, nested_attr))
    return plot_data


def extract_plot_data(contrast_plot_data, contrast_type):
    """
    Split contrast results into the four sequences a forest plot needs.

    Parameters
    ----------
    contrast_plot_data : iterable
        Result objects, e.g. the list returned by ``load_plot_data``.
    contrast_type : str
        ``'mini_meta'`` selects each result's ``bootstraps_weighted_delta``;
        any other value selects ``bootstraps_delta_delta``.

    Returns
    -------
    tuple of four lists
        ``(bootstraps, differences, bcalows, bcahighs)``, each in input order.
    """
    bootstrap_attr = "bootstraps_" + (
        "weighted_delta" if contrast_type == "mini_meta" else "delta_delta"
    )

    def _pluck(attr_name):
        # One full pass over the input per attribute (attribute-major order).
        return [getattr(item, attr_name) for item in contrast_plot_data]

    return (
        _pluck(bootstrap_attr),
        _pluck("difference"),
        _pluck("bca_low"),
        _pluck("bca_high"),
    )

def map_effect_attribute(attribute_key):
    """
    Return the human-readable axis label for an effect-size key.

    Parameters
    ----------
    attribute_key : str
        One of 'mean_diff', 'median_diff', 'cliffs_delta', 'cohens_d',
        'hedges_g' or 'delta_g'.

    Returns
    -------
    str
        Display name, e.g. 'Mean Difference' for 'mean_diff'.

    Raises
    ------
    TypeError
        If ``attribute_key`` is not a supported key. TypeError (rather than
        ValueError) is kept for backward compatibility with existing callers.
    """
    effect_labels = {
        "mean_diff": "Mean Difference",
        "median_diff": "Median Difference",
        "cliffs_delta": "Cliffs Delta",
        "cohens_d": "Cohens d",
        "hedges_g": "Hedges g",
        "delta_g": "Delta g",
    }
    try:
        return effect_labels[attribute_key]
    except KeyError:
        # Build the option list from the mapping so the message can never
        # drift out of sync with the supported keys.
        options = ", ".join(f"`{key}`" for key in effect_labels)
        raise TypeError(
            f"Invalid `effect_size`: {attribute_key!r}. "
            f"Please choose from the following effect sizes: {options}."
        ) from None

def _forest_violin_colors(custom_palette, contrast_labels, n_contrasts, desat_violin):
    """Resolve one desaturated violin colour per contrast from ``custom_palette``.

    ``custom_palette`` is expected to be already type-checked by ``forest_plot``
    (dict, list, str or None). Dict palettes are keyed by contrast label and fall
    back to the first default seaborn colour for labels they do not contain.
    """
    if not custom_palette:
        colors = sns.color_palette(n_colors=n_contrasts)
    elif isinstance(custom_palette, dict):
        fallback = sns.color_palette()[0]
        colors = [custom_palette.get(label, fallback) for label in contrast_labels]
    elif isinstance(custom_palette, list):
        colors = custom_palette[:n_contrasts]
    elif custom_palette in plt.colormaps():
        colors = sns.color_palette(custom_palette, n_contrasts)
    else:
        raise ValueError(
            f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
        )
    return [sns.desaturate(color, desat_violin) for color in colors]


def forest_plot(
    contrasts: List,
    selected_indices: Optional[List] = None,
    contrast_type: str = "delta2",
    effect_size: str = "mean_diff",
    contrast_labels: List[str] = None,
    ylabel: str = "effect size",
    plot_elements_to_extract: Optional[List] = None,
    title: str = "ΔΔ Forest",
    custom_palette: Optional[Union[dict, list, str]] = None,
    fontsize: int = 12,
    title_font_size: int = 16,
    violin_kwargs: Optional[dict] = None,
    marker_size: int = 20,
    ci_line_width: float = 2.5,
    desat_violin: float = 1,
    remove_spines: bool = True,
    ax: Optional[plt.Axes] = None,
    additional_plotting_kwargs: Optional[dict] = None,
    rotation_for_xlabels: int = 45,
    alpha_violin_plot: float = 0.8,
    horizontal: bool = False,
) -> plt.Figure:
    """
    Generate a forest plot from a list of contrast objects (e.g. those produced
    by DABEST-python).

    Each contrast is drawn as a half-violin of its bootstrap distribution, with
    a dot for the effect size and a bar for its BCa confidence interval.

    Parameters
    ----------
    contrasts : List
        Non-empty list of contrast objects.
    selected_indices : Optional[List], default=None
        Positions (into ``contrasts``) of the contrasts to plot, in the order
        given. ``None`` plots every contrast.
    contrast_type : str, default='delta2'
        Type of contrast to plot: 'delta2' or 'mini_meta'.
    effect_size : str, default='mean_diff'
        Effect size to plot: 'mean_diff', 'median_diff', 'cliffs_delta',
        'cohens_d', 'hedges_g' or 'delta_g'.
    contrast_labels : List[str], default=None
        One label per entry of ``contrasts`` (before ``selected_indices`` is
        applied). ``None`` labels the contrasts '1', '2', ... by position.
    ylabel : str, default='effect size'
        Label for the effect-size axis. When left at the default it is replaced
        by the display name of ``effect_size`` (e.g. 'Mean Difference').
    plot_elements_to_extract : Optional[List], default=None
        Currently unused; accepted for API compatibility.
    title : str, default='ΔΔ Forest'
        Plot title.
    custom_palette : Optional[Union[dict, list, str]], default=None
        Violin colours: a dict keyed by contrast label, a list of colours, or
        the name of a Matplotlib colormap. ``None`` uses the seaborn default.
    fontsize : int, default=12
        Font size for the contrast tick labels and the effect-size axis label.
    title_font_size : int, default=16
        Font size for the plot title.
    violin_kwargs : Optional[dict], default=None
        Keyword arguments for ``Axes.violinplot``. ``vert`` is always derived
        from ``horizontal``; the caller's dict is not modified.
    marker_size : int, default=20
        Marker size for the effect-size points.
    ci_line_width : float, default=2.5
        Line width of the confidence-interval bars.
    desat_violin : float, default=1
        Proportion passed to ``seaborn.desaturate`` for the violin colours.
    remove_spines : bool, default=True
        If True, hides every spine except the one on the effect-size axis.
    ax : Optional[plt.Axes], default=None
        Axes to draw on; a new figure and axes are created if None.
    additional_plotting_kwargs : Optional[dict], default=None
        Passed to ``ax.set`` after drawing, for further customisation.
    rotation_for_xlabels : int, default=45
        Rotation (0-360 degrees) of the contrast labels; only used when
        ``horizontal`` is False.
    alpha_violin_plot : float, default=0.8
        Opacity of the violins, between 0 and 1.
    horizontal : bool, default=False
        If True, contrasts are laid out along the y-axis and effect sizes along
        the x-axis.

    Returns
    -------
    plt.Figure
        The figure containing the forest plot.

    Raises
    ------
    TypeError, ValueError, IndexError
        On invalid arguments (see the validation checks below).
    """
    from .plot_tools import halfviolin

    # ----- Validate inputs -------------------------------------------------
    if contrasts is None:
        raise ValueError("The `contrasts` parameter cannot be None")

    if not isinstance(contrasts, list) or not contrasts:
        raise ValueError("The `contrasts` argument must be a non-empty list.")

    if selected_indices is not None:
        if not isinstance(selected_indices, list) or not all(
            isinstance(i, int) and not isinstance(i, bool) for i in selected_indices
        ):
            raise TypeError("The `selected_indices` must be a list of integers or `None`.")
        if not selected_indices:
            raise ValueError("`selected_indices` must contain at least one index.")
        n_total = len(contrasts)
        out_of_range = [i for i in selected_indices if not -n_total <= i < n_total]
        if out_of_range:
            raise IndexError(
                f"`selected_indices` {out_of_range} are out of range for {n_total} contrasts."
            )

    if not isinstance(contrast_type, str):
        raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")

    if not isinstance(effect_size, str):
        raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, `hedges_g`, and `delta_g`.")

    if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
        raise TypeError("The `contrast_labels` must be a list of strings or `None`.")

    if contrast_labels is not None and len(contrast_labels) != len(contrasts):
        raise ValueError("`contrast_labels` must match the number of `contrasts` if provided.")

    if not isinstance(ylabel, str):
        raise TypeError("The `ylabel` argument must be a string.")

    if custom_palette is not None and not isinstance(custom_palette, (dict, list, str)):
        raise TypeError("The `custom_palette` must be either a dictionary, list, string, or `None`.")

    if not isinstance(fontsize, (int, float)):
        raise TypeError("`fontsize` must be an integer or float.")

    if not isinstance(marker_size, (int, float)) or marker_size <= 0:
        raise TypeError("`marker_size` must be a positive integer or float.")

    if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
        raise TypeError("`ci_line_width` must be a positive integer or float.")

    if not isinstance(remove_spines, bool):
        raise TypeError("`remove_spines` must be a boolean value.")

    if ax is not None and not isinstance(ax, plt.Axes):
        raise TypeError("`ax` must be a `matplotlib.axes.Axes` instance or `None`.")

    if not isinstance(rotation_for_xlabels, (int, float)) or not 0 <= rotation_for_xlabels <= 360:
        raise TypeError("`rotation_for_xlabels` must be an integer or float between 0 and 360.")

    # Accept ints such as 0 or 1 as well as floats; bools are still rejected.
    if (
        isinstance(alpha_violin_plot, bool)
        or not isinstance(alpha_violin_plot, (int, float))
        or not 0 <= alpha_violin_plot <= 1
    ):
        raise TypeError("`alpha_violin_plot` must be a number between 0 and 1.")

    if not isinstance(horizontal, bool):
        raise TypeError("`horizontal` must be a boolean value.")

    # map_effect_attribute also rejects unknown effect sizes (TypeError).
    # Only the default label is replaced, so a caller-supplied ylabel is honoured.
    if effect_size:
        effect_label = map_effect_attribute(effect_size)
        if ylabel == "effect size":
            ylabel = effect_label

    # ----- Resolve labels and the subset of contrasts to plot ---------------
    if contrast_labels is None:
        # Positional fallback so tick labelling and dict palettes never iterate None.
        contrast_labels = [str(i) for i in range(1, len(contrasts) + 1)]
    else:
        contrast_labels = list(contrast_labels)

    if selected_indices is not None:
        contrasts = [contrasts[i] for i in selected_indices]
        contrast_labels = [contrast_labels[i] for i in selected_indices]

    n_contrasts = len(contrasts)

    # ----- Load data ---------------------------------------------------------
    contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)
    bootstraps, differences, bcalows, bcahighs = extract_plot_data(
        contrast_plot_data, contrast_type
    )

    # ----- Figure / axes -----------------------------------------------------
    if horizontal:
        fig_size = (4, 1.5 * n_contrasts)
    else:
        fig_size = (1.5 * n_contrasts, 4)

    if ax is None:
        fig, ax = plt.subplots(figsize=fig_size)
    else:
        fig = ax.figure

    # ----- Violins -----------------------------------------------------------
    # Copy so the caller's dict is not mutated when `vert` is injected.
    if violin_kwargs:
        violin_kwargs = dict(violin_kwargs)
    else:
        violin_kwargs = {
            "widths": 0.5,
            "showextrema": False,
            "showmedians": False,
        }
    violin_kwargs["vert"] = not horizontal
    v = ax.violinplot(bootstraps, **violin_kwargs)

    halfviolin(v, alpha=alpha_violin_plot, half="top" if horizontal else "right")

    violin_colors = _forest_violin_colors(
        custom_palette, contrast_labels, n_contrasts, desat_violin
    )
    for patch, color in zip(v["bodies"], violin_colors):
        patch.set_facecolor(color)
        patch.set_alpha(alpha_violin_plot)

    # ----- Zero reference line, point estimates and CIs ----------------------
    if horizontal:
        ax.plot([0, 0], [0, n_contrasts + 1], "k", linewidth=1)
    else:
        ax.plot([0, n_contrasts + 1], [0, 0], "k", linewidth=1)

    positions = range(1, n_contrasts + 1)
    for pos, diff, low, high in zip(positions, differences, bcalows, bcahighs):
        if horizontal:
            ax.plot(diff, pos, "k.", markersize=marker_size)
            ax.plot([low, high], [pos, pos], "k", linewidth=ci_line_width)
        else:
            ax.plot(pos, diff, "k.", markersize=marker_size)
            ax.plot([pos, pos], [low, high], "k", linewidth=ci_line_width)

    # ----- Labels, ticks and limits ------------------------------------------
    if horizontal:
        ax.set_yticks(positions)
        ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
        ax.set_xlabel(ylabel, fontsize=fontsize)
        ax.set_ylim([0.7, n_contrasts + 0.5])
    else:
        ax.set_xticks(positions)
        ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize)
        ax.set_xlim([0.7, n_contrasts + 0.5])

    ax.set_title(title, fontsize=title_font_size)

    if remove_spines:
        # Keep only the spine along the effect-size axis.
        hidden = ("left", "right", "top") if horizontal else ("top", "bottom", "right")
        for side in hidden:
            ax.spines[side].set_visible(False)

    if additional_plotting_kwargs:
        ax.set(**additional_plotting_kwargs)

    return fig