Skip to content

Commit

Permalink
Merge pull request #172 from ACCLAB/feat-forest-plot-pytest-fixes
Browse files Browse the repository at this point in the history
Fixing pytest failures, adding new Forestplot, tutorial notebook and image tests
  • Loading branch information
Jacobluke- committed Mar 14, 2024
2 parents d5b2884 + de164c4 commit 1cc16d3
Show file tree
Hide file tree
Showing 51 changed files with 1,466 additions and 157 deletions.
254 changes: 174 additions & 80 deletions dabest/forest_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@ def load_plot_data(
"""
Loads plot data based on specified effect size and contrast type.
Parameters:
contrasts (List): List of contrast objects.
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
contrast_type (str): Type of contrast ('delta2', 'mini_meta').
Parameters
----------
contrasts : List
List of contrast objects.
effect_size: str
Type of effect size ('mean_diff', 'median_diff', etc.).
contrast_type: str
Type of contrast ('delta2', 'mini_meta').
Returns:
Returns
-------
List: Contrast plot data based on specified parameters.
"""
effect_attr_map = {
Expand All @@ -31,24 +36,27 @@ def load_plot_data(
"cliffs_delta": "cliffs_delta",
"cohens_d": "cohens_d",
"hedges_g": "hedges_g",
"delta_g": "delta_g"
}

contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta"}
contrast_attr_map = {"delta2": "delta_delta", "mini_meta": "mini_meta_delta"}

effect_attr = effect_attr_map.get(effect_size)
contrast_attr = contrast_attr_map.get(contrast_type, "delta_delta")
contrast_attr = contrast_attr_map.get(contrast_type)

if not effect_attr:
raise ValueError(f"Invalid effect_size: {effect_size}")
raise ValueError(f"Invalid effect_size: {effect_size}")
if not contrast_attr:
raise ValueError(f"Invalid contrast_type: {contrast_type}. Available options: [`delta2`, `mini_meta`]")

return [
getattr(getattr(contrast, effect_attr), contrast_attr) for contrast in contrasts
]


def extract_plot_data(contrast_plot_data, contrast_labels):
def extract_plot_data(contrast_plot_data, contrast_type):
"""Extracts bootstrap, difference, and confidence intervals based on contrast labels."""
if contrast_labels == "mini_meta":
if contrast_type == "mini_meta":
attribute_suffix = "weighted_delta"
else:
attribute_suffix = "delta_delta"
Expand All @@ -57,26 +65,25 @@ def extract_plot_data(contrast_plot_data, contrast_labels):
getattr(result, f"bootstraps_{attribute_suffix}")
for result in contrast_plot_data
]

differences = [result.difference for result in contrast_plot_data]
bcalows = [result.bca_low for result in contrast_plot_data]
bcahighs = [result.bca_high for result in contrast_plot_data]

return bootstraps, differences, bcalows, bcahighs


def forest_plot(
contrasts: List,
selected_indices: Optional[List] = None,
analysis_type: str = "delta2",
contrast_type: str = "delta2",
xticklabels: Optional[List] = None,
effect_size: str = "mean_diff",
contrast_labels: str = "delta_delta",
ylabel: str = "ΔΔ Volume (nL)",
contrast_labels: List[str] = None,
ylabel: str = "value",
plot_elements_to_extract: Optional[List] = None,
title: str = "ΔΔ Forest",
custom_palette: Optional[
Union[dict, list, str]
] = None, # Custom color palette parameter
custom_palette: Optional[Union[dict, list, str]] = None,
fontsize: int = 20,
violin_kwargs: Optional[dict] = None,
marker_size: int = 20,
Expand All @@ -87,73 +94,158 @@ def forest_plot(
additional_plotting_kwargs: Optional[dict] = None,
rotation_for_xlabels: int = 45,
alpha_violin_plot: float = 0.4,
) -> plt.Figure:
"""
Generates a customized forest plot using contrast objects from DABEST-python package or similar.
Parameters:
contrasts (List): List of contrast objects.
selected_indices (Optional[List]): Indices of contrasts to be plotted, if not all.
analysis_type (str): Type of analysis ('delta2', 'minimeta').
xticklabels (Optional[List]): Custom labels for x-axis ticks.
effect_size (str): Type of effect size ('mean_diff', 'median_diff', etc.).
contrast_labels (str): Labels for each contrast.
ylabel (str): Label for the y-axis.
plot_elements_to_extract (Optional[List]): Plot elements to be extracted for custom plotting.
title (str): Title of the plot.
ylim (Tuple[float, float]): y-axis limits.
custom_palette (Optional[Union[dict, list, str]]): Custom palette for violin plots.
fontsize (int): Font size for labels.
violin_kwargs (Optional[dict]): Additional kwargs for violin plots.
marker_size (int): Size of the markers for mean differences.
ci_line_width (float): Line width for confidence intervals.
zero_line_width (int): Width of the zero line.
remove_spines (bool): Whether to remove the plot spines.
ax (Optional[plt.Axes]): Axes object to plot on, if provided.
additional_plotting_kwargs (Optional[dict]): Additional plotting parameters.
rotation_for_xlabels (int): Rotation angle for x-axis labels.
alpha_violin_plot (float): Transparency level for violin plots.
Returns:
plt.Figure: The matplotlib figure object with the plot.
horizontal: bool = False # New argument for horizontal orientation
)-> plt.Figure:
"""
Custom function that generates a forest plot from given contrast objects, suitable for a range of data analysis types, including those from packages like DABEST-python.
Parameters
----------
contrasts : List
List of contrast objects.
selected_indices : Optional[List], default=None
Indices of specific contrasts to plot, if not plotting all.
analysis_type : str
the type of analysis (e.g., 'delta2', 'minimeta').
xticklabels : Optional[List], default=None
Custom labels for the x-axis ticks.
effect_size : str
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
contrast_labels : List[str]
Labels for each contrast.
ylabel : str
Label for the y-axis, describing the plotted data or effect size.
plot_elements_to_extract : Optional[List], default=None
Elements to extract for detailed plot customization.
title : str
Plot title, summarizing the visualized data.
ylim : Tuple[float, float]
Limits for the y-axis.
custom_palette : Optional[Union[dict, list, str]], default=None
Custom color palette for the plot.
fontsize : int
Font size for text elements in the plot.
violin_kwargs : Optional[dict], default=None
Additional arguments for violin plot customization.
marker_size : int
Marker size for plotting mean differences or effect sizes.
ci_line_width : float
Width of confidence interval lines.
zero_line_width : int
Width of the line indicating zero effect size.
remove_spines : bool, default=False
If True, removes top and right plot spines.
ax : Optional[plt.Axes], default=None
Matplotlib Axes object for the plot; creates new if None.
additional_plotting_kwargs : Optional[dict], default=None
Further customization arguments for the plot.
rotation_for_xlabels : int, default=0
Rotation angle for x-axis labels, improving readability.
alpha_violin_plot : float, default=1.0
Transparency level for violin plots.
Returns
-------
plt.Figure
The matplotlib figure object with the generated forest plot.
"""
from .plot_tools import halfviolin

# Validate inputs
if contrasts is None:
raise ValueError("The `contrasts` parameter cannot be None")

if not isinstance(contrasts, list) or not contrasts:
raise ValueError("The `contrasts` argument must be a non-empty list.")

if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
raise TypeError("The `selected_indices` must be a list of integers or `None`.")

if not isinstance(contrast_type, str):
raise TypeError("The `contrast_type` argument must be a string.")

if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
raise TypeError("The `xticklabels` must be a list of strings or `None`.")

if not isinstance(effect_size, str):
raise TypeError("The `effect_size` argument must be a string.")

if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")

if contrast_labels is not None and len(contrast_labels) != len(contrasts):
raise ValueError("`contrast_labels` must match the number of `contrasts` if provided.")

if not isinstance(ylabel, str):
raise TypeError("The `ylabel` argument must be a string.")

if custom_palette is not None and not isinstance(custom_palette, (dict, list, str, type(None))):
raise TypeError("The `custom_palette` must be either a dictionary, list, string, or `None`.")

if not isinstance(fontsize, (int, float)):
raise TypeError("`fontsize` must be an integer or float.")

if not isinstance(marker_size, (int, float)) or marker_size <= 0:
raise TypeError("`marker_size` must be a positive integer or float.")

if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
raise TypeError("`ci_line_width` must be a positive integer or float.")

if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
raise TypeError("`zero_line_width` must be a positive integer or float.")

if not isinstance(remove_spines, bool):
raise TypeError("`remove_spines` must be a boolean value.")

if ax is not None and not isinstance(ax, plt.Axes):
raise TypeError("`ax` must be a `matplotlib.axes.Axes` instance or `None`.")

if not isinstance(rotation_for_xlabels, (int, float)) or not 0 <= rotation_for_xlabels <= 360:
raise TypeError("`rotation_for_xlabels` must be an integer or float between 0 and 360.")

if not isinstance(alpha_violin_plot, float) or not 0 <= alpha_violin_plot <= 1:
raise TypeError("`alpha_violin_plot` must be a float between 0 and 1.")

if not isinstance(horizontal, bool):
raise TypeError("`horizontal` must be a boolean value.")

# Load plot data
contrast_plot_data = load_plot_data(contrasts, effect_size, analysis_type)
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)

# Extract data for plotting
bootstraps, differences, bcalows, bcahighs = extract_plot_data(
contrast_plot_data, contrast_labels
contrast_plot_data, contrast_type
)

# Infer the figsize based on the number of contrasts
# Adjust figure size based on orientation
all_groups_count = len(contrasts)
each_group_width_inches = 2.5 # Adjust as needed for width
base_height_inches = 4 # Base height, adjust as needed
height_inches = base_height_inches
width_inches = each_group_width_inches * all_groups_count
fig_size = (width_inches, height_inches)
if horizontal:
fig_size = (4, 1.5 * all_groups_count)
else:
fig_size = (1.5 * all_groups_count, 4)

# Create figure and axes if not provided
if ax is None:
fig, ax = plt.subplots(figsize=fig_size)
else:
fig = ax.figure

# Zero line
ax.plot([0, len(contrasts) + 1], [0, 0], "k", linewidth=zero_line_width)

# Violin plots with customizable colors
# Adjust violin plot orientation based on the 'horizontal' argument
violin_kwargs = violin_kwargs or {
"widths": 0.5,
"vert": True,
"showextrema": False,
"showmedians": False,
}
violin_kwargs["vert"] = not horizontal
v = ax.violinplot(bootstraps, **violin_kwargs)
halfviolin(v, alpha=alpha_violin_plot) # Apply halfviolin from dabest

# Adjust the halfviolin function call based on 'horizontal'
if horizontal:
half = "top"
else:
half = "right" # Assuming "right" is the default or another appropriate value

# Assuming halfviolin has been updated to accept a 'half' parameter
halfviolin(v, alpha=alpha_violin_plot, half=half)

# Handle the custom color palette
if custom_palette:
if isinstance(custom_palette, dict):
Expand All @@ -176,30 +268,32 @@ def forest_plot(
patch.set_facecolor(color)
patch.set_alpha(alpha_violin_plot)

# Effect size dot and confidence interval
# Flipping the axes for plotting based on 'horizontal'
for k in range(1, len(contrasts) + 1):
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)

# Custom settings
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(
xticklabels or range(1, len(contrasts) + 1),
rotation=rotation_for_xlabels,
fontsize=fontsize,
)
ax.set_xlim([0, len(contrasts) + 1])
ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_title(title, fontsize=fontsize)
ylim = (min(bcalows) - 0.25, max(bcahighs) + 0.25)
ax.set_ylim(ylim)
if horizontal:
ax.plot(differences[k - 1], k, "k.", markersize=marker_size) # Flipped axes
ax.plot([bcalows[k - 1], bcahighs[k - 1]], [k, k], "k", linewidth=ci_line_width) # Flipped axes
else:
ax.plot(k, differences[k - 1], "k.", markersize=marker_size)
ax.plot([k, k], [bcalows[k - 1], bcahighs[k - 1]], "k", linewidth=ci_line_width)

# Remove spines if requested
# Adjusting labels, ticks, and limits based on 'horizontal'
if horizontal:
ax.set_yticks(range(1, len(contrasts) + 1))
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_xlabel(ylabel, fontsize=fontsize)
else:
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)

# Setting the title and adjusting spines as before
ax.set_title(title, fontsize=fontsize)
if remove_spines:
for spine in ax.spines.values():
spine.set_visible(False)

# Additional customization
# Apply additional customizations if provided
if additional_plotting_kwargs:
ax.set(**additional_plotting_kwargs)

Expand Down

0 comments on commit 1cc16d3

Please sign in to comment.