diff --git a/forestplot/mplot_graph_utils.py b/forestplot/mplot_graph_utils.py index a10bdd3..edd381a 100644 --- a/forestplot/mplot_graph_utils.py +++ b/forestplot/mplot_graph_utils.py @@ -50,18 +50,37 @@ def mdraw_ref_xline( return ax -# ============================================================================================= -# ============================================================================================= -# ============================================================================================= def mdraw_yticklabels( dataframe: pd.core.frame.DataFrame, yticklabel: str, - model_col: str, - models: Optional[Union[Sequence[str], None]], flush: bool, ax: Axes, **kwargs: Any, ) -> Axes: + """ + Set custom y-axis tick labels on a matplotlib Axes object using the yticklabel column in the provided + pandas dataframe. + + Parameters + ---------- + dataframe : pd.core.frame.DataFrame + The pandas DataFrame from which the y-axis tick labels are derived. + yticklabel : str + Column name in the DataFrame whose values are used as y-axis tick labels. + flush : bool + If True, aligns y-axis tick labels to the left with adjusted padding to prevent overlap. + If False, aligns labels to the right. + ax : Axes + The matplotlib Axes object to be modified. + **kwargs : Any + Additional keyword arguments for customizing the appearance of the tick labels. + Supported customizations include 'fontfamily' (default 'monospace') and 'fontsize' (default 12). + + Returns + ------- + Axes + The modified matplotlib Axes object with updated y-axis tick labels. + """ ax.set_yticks(range(len(dataframe))) fontfamily = kwargs.get("fontfamily", "monospace") @@ -72,10 +91,16 @@ def mdraw_yticklabels( ) yax = ax.get_yaxis() fig = plt.gcf() - pad = max( - T.label.get_window_extent(renderer=fig.canvas.get_renderer()).width - for T in yax.majorTicks - ) + try: + pad = max( + T.label.get_window_extent(renderer=fig.canvas.get_renderer()).width + for T in yax.majorTicks + ) + except AttributeError: + pad = max( + T.label1.get_window_extent(renderer=fig.canvas.get_renderer()).width + for T in yax.majorTicks + ) yax.set_tick_params(pad=pad) else: ax.set_yticklabels( diff --git a/tests/test_mplot_graph_utils.py b/tests/test_mplot_graph_utils.py new file mode 100644 index 0000000..bcc5e9b --- /dev/null +++ b/tests/test_mplot_graph_utils.py @@ -0,0 +1,52 @@ +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.pyplot import Axes + +from forestplot.mplot_graph_utils import mdraw_ref_xline, mdraw_yticklabels + +x, y = [0, 1, 2], [0, 1, 2] +str_vector = ["a", "b", "c"] +input_df = pd.DataFrame( + { + "yticklabel": str_vector, + "estimate": x, + "moerror": y, + "ll": x, + "hl": y, + "pval": y, + "formatted_pval": y, + "yticklabel1": str_vector, + "yticklabel2": str_vector, + } +) + + +def test_mdraw_ref_xline(): + _, ax = plt.subplots() + ax = mdraw_ref_xline( + ax, + dataframe=input_df, + model_col="yticklabel", + annoteheaders=None, + right_annoteheaders=None, + ) + assert isinstance(ax, Axes) + + +def test_mdraw_yticklabels(): + # Prepare the input DataFrame + str_vector = ["a", "b", "c"] + input_df = pd.DataFrame( + { + "yticklabel": str_vector, + } + ) + + # Create a matplotlib Axes object + _, ax = plt.subplots() + + # Call the function + ax = mdraw_yticklabels(input_df, yticklabel="yticklabel", flush=True, ax=ax) + + assert isinstance(ax, Axes) + assert [label.get_text() for label in ax.get_yticklabels()] == str_vector