Skip to content

Commit

Permalink
tst
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacobluke- committed Apr 16, 2024
1 parent 9b0c918 commit cb195d8
Show file tree
Hide file tree
Showing 103 changed files with 109 additions and 67 deletions.
81 changes: 52 additions & 29 deletions dabest/forest_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/API/forest_plot.ipynb.

# %% auto 0
__all__ = ['load_plot_data', 'extract_plot_data', 'forest_plot']
__all__ = ['load_plot_data', 'extract_plot_data', 'map_effect_attribute', 'forest_plot']

# %% ../nbs/API/forest_plot.ipynb 5
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -72,28 +72,42 @@ def extract_plot_data(contrast_plot_data, contrast_type):

return bootstraps, differences, bcalows, bcahighs

def map_effect_attribute(attribute_key):
# Check if the attribute key exists in the dictionary
effect_attr_map = {
"mean_diff": "Mean Difference",
"median_diff": "Median Difference",
"cliffs_delta": "Cliffs Delta",
"cohens_d": "Cohens d",
"hedges_g": "Hedges g",
"delta_g": "Delta g"
}
if attribute_key in effect_attr_map:
return effect_attr_map[attribute_key]
else:
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.") # Return a default value or message if the key is not found

def forest_plot(
contrasts: List,
selected_indices: Optional[List] = None,
contrast_type: str = "delta2",
xticklabels: Optional[List] = None,
effect_size: str = "mean_diff",
contrast_labels: List[str] = None,
ylabel: str = "value",
ylabel: str = "effect size",
plot_elements_to_extract: Optional[List] = None,
title: str = "ΔΔ Forest",
custom_palette: Optional[Union[dict, list, str]] = None,
fontsize: int = 20,
fontsize: int = 12,
title_font_size: int =16,
violin_kwargs: Optional[dict] = None,
marker_size: int = 20,
ci_line_width: float = 2.5,
zero_line_width: int = 1,
desat_violin: float = 1,
remove_spines: bool = True,
ax: Optional[plt.Axes] = None,
additional_plotting_kwargs: Optional[dict] = None,
rotation_for_xlabels: int = 45,
alpha_violin_plot: float = 0.4,
alpha_violin_plot: float = 0.8,
horizontal: bool = False # New argument for horizontal orientation
)-> plt.Figure:
"""
Expand All @@ -106,11 +120,9 @@ def forest_plot(
selected_indices : Optional[List], default=None
Indices of specific contrasts to plot, if not plotting all.
analysis_type : str
the type of analysis (e.g., 'delta2', 'minimeta').
xticklabels : Optional[List], default=None
Custom labels for the x-axis ticks.
the type of analysis (e.g., 'delta2', 'mini_meta').
effect_size : str
Type of effect size to plot (e.g., 'mean_diff', 'median_diff').
Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).
contrast_labels : List[str]
Labels for each contrast.
ylabel : str
Expand All @@ -125,14 +137,14 @@ def forest_plot(
Custom color palette for the plot.
fontsize : int
Font size for text elements in the plot.
title_font_size: int =16
Font size for text of plot title.
violin_kwargs : Optional[dict], default=None
Additional arguments for violin plot customization.
marker_size : int
Marker size for plotting mean differences or effect sizes.
ci_line_width : float
Width of confidence interval lines.
zero_line_width : int
Width of the line indicating zero effect size.
remove_spines : bool, default=False
If True, removes top and right plot spines.
ax : Optional[plt.Axes], default=None
Expand Down Expand Up @@ -161,14 +173,13 @@ def forest_plot(
if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):
raise TypeError("The `selected_indices` must be a list of integers or `None`.")

# For the 'contrast_type' parameter
if not isinstance(contrast_type, str):
raise TypeError("The `contrast_type` argument must be a string.")

if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):
raise TypeError("The `xticklabels` must be a list of strings or `None`.")

raise TypeError("The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.")

# For the 'effect_size' parameter
if not isinstance(effect_size, str):
raise TypeError("The `effect_size` argument must be a string.")
raise TypeError("The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.")

if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):
raise TypeError("The `contrast_labels` must be a list of strings or `None`.")
Expand All @@ -191,9 +202,6 @@ def forest_plot(
if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:
raise TypeError("`ci_line_width` must be a positive integer or float.")

if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:
raise TypeError("`zero_line_width` must be a positive integer or float.")

if not isinstance(remove_spines, bool):
raise TypeError("`remove_spines` must be a boolean value.")

Expand All @@ -209,6 +217,8 @@ def forest_plot(
if not isinstance(horizontal, bool):
raise TypeError("`horizontal` must be a boolean value.")

if (effect_size and isinstance(effect_size, str)):
ylabel = map_effect_attribute(effect_size)
# Load plot data
contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)

Expand Down Expand Up @@ -250,7 +260,7 @@ def forest_plot(
if custom_palette:
if isinstance(custom_palette, dict):
violin_colors = [
custom_palette.get(c, sns.color_palette()[0]) for c in contrasts
custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels
]
elif isinstance(custom_palette, list):
violin_colors = custom_palette[: len(contrasts)]
Expand All @@ -262,12 +272,18 @@ def forest_plot(
f"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette."
)
else:
violin_colors = sns.color_palette()[: len(contrasts)]
violin_colors = sns.color_palette(n_colors=len(contrasts))

violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]

for patch, color in zip(v["bodies"], violin_colors):
patch.set_facecolor(color)
patch.set_alpha(alpha_violin_plot)

if horizontal:
ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)
else:
ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)

# Flipping the axes for plotting based on 'horizontal'
for k in range(1, len(contrasts) + 1):
if horizontal:
Expand All @@ -280,19 +296,26 @@ def forest_plot(
# Adjusting labels, ticks, and limits based on 'horizontal'
if horizontal:
ax.set_yticks(range(1, len(contrasts) + 1))
ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)
ax.set_xlabel(ylabel, fontsize=fontsize)
ax.set_ylim([0.7, len(contrasts) + 0.5])
else:
ax.set_xticks(range(1, len(contrasts) + 1))
ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
ax.set_xlim([0.7, len(contrasts) + 0.5])

# Setting the title and adjusting spines as before
ax.set_title(title, fontsize=fontsize)
ax.set_title(title, fontsize=title_font_size)
if remove_spines:
for spine in ax.spines.values():
spine.set_visible(False)

if horizontal:
ax.spines['left'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
else:
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['right'].set_visible(False)
# Apply additional customizations if provided
if additional_plotting_kwargs:
ax.set(**additional_plotting_kwargs)
Expand Down
79 changes: 51 additions & 28 deletions nbs/API/forest_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,42 @@
" \n",
" return bootstraps, differences, bcalows, bcahighs\n",
"\n",
"def map_effect_attribute(attribute_key):\n",
" # Check if the attribute key exists in the dictionary\n",
" effect_attr_map = {\n",
" \"mean_diff\": \"Mean Difference\",\n",
" \"median_diff\": \"Median Difference\",\n",
" \"cliffs_delta\": \"Cliffs Delta\",\n",
" \"cohens_d\": \"Cohens d\",\n",
" \"hedges_g\": \"Hedges g\",\n",
" \"delta_g\": \"Delta g\"\n",
" }\n",
" if attribute_key in effect_attr_map:\n",
" return effect_attr_map[attribute_key]\n",
" else:\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`,`median_diff`,`cliffs_delta`,`cohens_d``, and `hedges_g`.\") # Return a default value or message if the key is not found\n",
"\n",
"def forest_plot(\n",
" contrasts: List,\n",
" selected_indices: Optional[List] = None,\n",
" contrast_type: str = \"delta2\",\n",
" xticklabels: Optional[List] = None,\n",
" effect_size: str = \"mean_diff\",\n",
" contrast_labels: List[str] = None,\n",
" ylabel: str = \"value\",\n",
" ylabel: str = \"effect size\",\n",
" plot_elements_to_extract: Optional[List] = None,\n",
" title: str = \"ΔΔ Forest\",\n",
" custom_palette: Optional[Union[dict, list, str]] = None,\n",
" fontsize: int = 20,\n",
" fontsize: int = 12,\n",
" title_font_size: int =16,\n",
" violin_kwargs: Optional[dict] = None,\n",
" marker_size: int = 20,\n",
" ci_line_width: float = 2.5,\n",
" zero_line_width: int = 1,\n",
" desat_violin: float = 1,\n",
" remove_spines: bool = True,\n",
" ax: Optional[plt.Axes] = None,\n",
" additional_plotting_kwargs: Optional[dict] = None,\n",
" rotation_for_xlabels: int = 45,\n",
" alpha_violin_plot: float = 0.4,\n",
" alpha_violin_plot: float = 0.8,\n",
" horizontal: bool = False # New argument for horizontal orientation\n",
")-> plt.Figure:\n",
" \"\"\" \n",
Expand All @@ -167,11 +181,9 @@
" selected_indices : Optional[List], default=None\n",
" Indices of specific contrasts to plot, if not plotting all.\n",
" analysis_type : str\n",
" the type of analysis (e.g., 'delta2', 'minimeta').\n",
" xticklabels : Optional[List], default=None\n",
" Custom labels for the x-axis ticks.\n",
" the type of analysis (e.g., 'delta2', 'mini_meta').\n",
" effect_size : str\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff').\n",
" Type of effect size to plot (e.g., 'mean_diff', 'median_diff', `cliffs_delta`,`cohens_d``, and `hedges_g`).\n",
" contrast_labels : List[str]\n",
" Labels for each contrast.\n",
" ylabel : str\n",
Expand All @@ -186,14 +198,14 @@
" Custom color palette for the plot.\n",
" fontsize : int\n",
" Font size for text elements in the plot.\n",
" title_font_size: int =16\n",
" Font size for text of plot title.\n",
" violin_kwargs : Optional[dict], default=None\n",
" Additional arguments for violin plot customization.\n",
" marker_size : int\n",
" Marker size for plotting mean differences or effect sizes.\n",
" ci_line_width : float\n",
" Width of confidence interval lines.\n",
" zero_line_width : int\n",
" Width of the line indicating zero effect size.\n",
" remove_spines : bool, default=False\n",
" If True, removes top and right plot spines.\n",
" ax : Optional[plt.Axes], default=None\n",
Expand Down Expand Up @@ -222,14 +234,13 @@
" if selected_indices is not None and not isinstance(selected_indices, (list, type(None))):\n",
" raise TypeError(\"The `selected_indices` must be a list of integers or `None`.\")\n",
" \n",
" # For the 'contrast_type' parameter\n",
" if not isinstance(contrast_type, str):\n",
" raise TypeError(\"The `contrast_type` argument must be a string.\")\n",
" \n",
" if xticklabels is not None and not all(isinstance(label, str) for label in xticklabels):\n",
" raise TypeError(\"The `xticklabels` must be a list of strings or `None`.\")\n",
" \n",
" raise TypeError(\"The `contrast_type` argument must be a string. Please choose from `delta2` and `mini_meta`.\")\n",
"\n",
" # For the 'effect_size' parameter\n",
" if not isinstance(effect_size, str):\n",
" raise TypeError(\"The `effect_size` argument must be a string.\")\n",
" raise TypeError(\"The `effect_size` argument must be a string. Please choose from the following effect sizes: `mean_diff`, `median_diff`, `cliffs_delta`, `cohens_d`, and `hedges_g`.\")\n",
" \n",
" if contrast_labels is not None and not all(isinstance(label, str) for label in contrast_labels):\n",
" raise TypeError(\"The `contrast_labels` must be a list of strings or `None`.\")\n",
Expand All @@ -252,9 +263,6 @@
" if not isinstance(ci_line_width, (int, float)) or ci_line_width <= 0:\n",
" raise TypeError(\"`ci_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(zero_line_width, (int, float)) or zero_line_width <= 0:\n",
" raise TypeError(\"`zero_line_width` must be a positive integer or float.\")\n",
" \n",
" if not isinstance(remove_spines, bool):\n",
" raise TypeError(\"`remove_spines` must be a boolean value.\")\n",
" \n",
Expand All @@ -270,6 +278,8 @@
" if not isinstance(horizontal, bool):\n",
" raise TypeError(\"`horizontal` must be a boolean value.\")\n",
"\n",
" if (effect_size and isinstance(effect_size, str)):\n",
" ylabel = map_effect_attribute(effect_size)\n",
" # Load plot data\n",
" contrast_plot_data = load_plot_data(contrasts, effect_size, contrast_type)\n",
"\n",
Expand Down Expand Up @@ -311,7 +321,7 @@
" if custom_palette:\n",
" if isinstance(custom_palette, dict):\n",
" violin_colors = [\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrasts\n",
" custom_palette.get(c, sns.color_palette()[0]) for c in contrast_labels\n",
" ]\n",
" elif isinstance(custom_palette, list):\n",
" violin_colors = custom_palette[: len(contrasts)]\n",
Expand All @@ -323,12 +333,18 @@
" f\"The specified `custom_palette` {custom_palette} is not a recognized Matplotlib palette.\"\n",
" )\n",
" else:\n",
" violin_colors = sns.color_palette()[: len(contrasts)]\n",
" violin_colors = sns.color_palette(n_colors=len(contrasts))\n",
"\n",
" violin_colors = [sns.desaturate(color, desat_violin) for color in violin_colors]\n",
" \n",
" for patch, color in zip(v[\"bodies\"], violin_colors):\n",
" patch.set_facecolor(color)\n",
" patch.set_alpha(alpha_violin_plot)\n",
"\n",
" if horizontal:\n",
" ax.plot([0, 0], [0, len(contrasts)+1], 'k', linewidth = 1)\n",
" else:\n",
" ax.plot([0, len(contrasts)+1], [0, 0], 'k', linewidth = 1)\n",
" \n",
" # Flipping the axes for plotting based on 'horizontal'\n",
" for k in range(1, len(contrasts) + 1):\n",
" if horizontal:\n",
Expand All @@ -341,19 +357,26 @@
" # Adjusting labels, ticks, and limits based on 'horizontal'\n",
" if horizontal:\n",
" ax.set_yticks(range(1, len(contrasts) + 1))\n",
" ax.set_yticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_yticklabels(contrast_labels, rotation=0, fontsize=fontsize)\n",
" ax.set_xlabel(ylabel, fontsize=fontsize)\n",
" ax.set_ylim([0.7, len(contrasts) + 0.5])\n",
" else:\n",
" ax.set_xticks(range(1, len(contrasts) + 1))\n",
" ax.set_xticklabels(contrast_labels, rotation=rotation_for_xlabels, fontsize=fontsize)\n",
" ax.set_ylabel(ylabel, fontsize=fontsize)\n",
" ax.set_xlim([0.7, len(contrasts) + 0.5])\n",
"\n",
" # Setting the title and adjusting spines as before\n",
" ax.set_title(title, fontsize=fontsize)\n",
" ax.set_title(title, fontsize=title_font_size)\n",
" if remove_spines:\n",
" for spine in ax.spines.values():\n",
" spine.set_visible(False)\n",
"\n",
" if horizontal:\n",
" ax.spines['left'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" else:\n",
" ax.spines['top'].set_visible(False)\n",
" ax.spines['bottom'].set_visible(False)\n",
" ax.spines['right'].set_visible(False)\n",
" # Apply additional customizations if provided\n",
" if additional_plotting_kwargs:\n",
" ax.set(**additional_plotting_kwargs)\n",
Expand Down
9 changes: 4 additions & 5 deletions nbs/tests/data/mocked_data_test_forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,20 @@
"contrasts": dummy_contrasts, # Ensure this is a list of contrast objects.
"selected_indices": None, # Valid as None or a list of integers.
"contrast_type": "delta2", # Ensure it's a string and one of the allowed contrast types.
"xticklabels": None, # Valid as None or a list of strings.
"effect_size": "mean_diff", # Ensure it's a string.
"contrast_labels": ["Drug1"], # This should be a list of strings.
"ylabel": "Effect Size", # Ensure it's a string.
"plot_elements_to_extract": None, # No specific checks needed based on your tests.
"title": "ΔΔ Forest Plot", # Ensure it's a string.
#"plot_elements_to_extract": None, # No specific checks needed based on your tests.
#"title": "ΔΔ Forest Plot", # Ensure it's a string.
"custom_palette": None, # Valid as None, a dictionary, list, or string.
"fontsize": 20, # Ensure it's an integer or float.
"violin_kwargs": None, # No specific checks needed based on your tests.
"marker_size": 20, # Ensure it's a positive integer or float.
"ci_line_width": 2.5, # Ensure it's a positive integer or float.
"zero_line_width": 1, # Ensure it's a positive integer or float.
"remove_spines": True, # Ensure it's a boolean.
"additional_plotting_kwargs": None, # No specific checks needed based on your tests.
"rotation_for_xlabels": 45, # Ensure it's an integer or float between 0 and 360.
"alpha_violin_plot": 0.4, # Ensure it's a float between 0 and 1.
"alpha_violin_plot": 0.8, # Ensure it's a float between 0 and 1.
"horizontal": False, # Ensure it's a boolean.
}

Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_113_desat.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_115_invert_ylim.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_117_err_color.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_119_wide_df_nan.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_11_inset_plots.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_120_long_df_nan.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_130_zero_to_one.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_131_one_to_zero.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_18_desat.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_19_dot_sizes.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_20_change_ylims.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_21_invert_ylim.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_24_wide_df_nan.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_25_long_df_nan.png
Binary file modified nbs/tests/mpl_image_tests/baseline_images/test_99_style_sheets.png

0 comments on commit cb195d8

Please sign in to comment.