Skip to content

Commit

Permalink
Add test for mdraw_legend (#88, #89) (#94)
Browse files Browse the repository at this point in the history
* Add test for mdraw_legend (#88, #89)

* Troubleshooting older py/mpl ver

* Troubleshooting older py/mpl ver

* Pleasing linters
  • Loading branch information
LSYS committed Dec 16, 2023
1 parent 9933ffa commit aeea499
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
35 changes: 35 additions & 0 deletions forestplot/mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,41 @@ def mdraw_legend(
mcolor: Union[Sequence[str], None] = ["0", "0.4", ".8", "0.2"],
**kwargs: Any,
) -> Axes:
"""
Add a custom legend to a matplotlib Axes object for the different models.
This function creates and adds a legend to a given Axes object, allowing for customization of
the legend's markers, colors, size, and positioning. It's particularly useful for graphs
representing different models or categories with distinct markers and colors.
Parameters
----------
ax : Axes
The matplotlib Axes object to which the legend will be added.
xlabel : Union[Sequence[str], None]
A sequence of strings for x-axis labels, used to adjust the legend position. If None, the default position is used.
modellabels : Optional[Union[Sequence[str], None]]
A sequence of strings that serve as labels for the legend entries.
msymbols : Union[Sequence[str], None], optional
A sequence of marker symbols for each legend entry, defaults to 'soDx'.
mcolor : Union[Sequence[str], None], optional
A sequence of colors for each legend entry, defaults to ["0", "0.4", ".8", "0.2"].
**kwargs : Any
Additional keyword arguments for further customization. Supported customizations include 'leg_markersize'
(size of the legend markers, default 8), 'bbox_to_anchor' (tuple specifying the anchor point of the legend),
'leg_loc' (location of the legend, default 'lower center' or 'best'), 'leg_ncol' (number of columns in the legend,
default 2 or 1), and 'leg_fontsize' (font size of legend text, default 12).
Returns
-------
Axes
The modified matplotlib Axes object with the legend added.
Notes
-----
- The 'xlabel' parameter is used to adjust the legend's position based on the presence of x-axis labels.
It does not directly set the x-axis labels.
"""
leg_markersize = kwargs.get("leg_markersize", 8)
leg_artists = []
for ix, symbol in enumerate(msymbols):
Expand Down
32 changes: 32 additions & 0 deletions tests/test_mplot_graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.pyplot import Axes

from forestplot.mplot_graph_utils import (
mdraw_ci,
mdraw_est_markers,
mdraw_legend,
mdraw_ref_xline,
mdraw_yticklabels,
)
Expand Down Expand Up @@ -95,3 +97,33 @@ def test_mdraw_ci():
# Assertions
assert isinstance(ax, Axes)
assert len(ax.collections) == len(set(models_vector))

def test_mdraw_legend():
# Create a simple plot
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], marker="o", color="0")
ax.plot([0, 1], [1, 0], marker="s", color="0.4")

# Sample parameters for the legend
modellabels = ["Model 1", "Model 2"]
msymbols = ["o", "s"]
mcolor = ["0", "0.4"]

# Call the function
ax = mdraw_legend(ax, None, modellabels, msymbols, mcolor)

# Assertions
legend = ax.get_legend()
assert legend is not None, "Legend was not created."

# Check number of legend entries
assert len(legend.get_texts()) == len(modellabels), "Incorrect number of legend entries."

# Check legend labels
for label, model_label in zip(legend.get_texts(), modellabels):
assert label.get_text() == model_label, "Legend labels do not match."

# Check legend marker colors and symbols
for line, color in zip(legend.legendHandles, mcolor):
assert isinstance(line, Line2D), "Legend entry is not a Line2D instance."
assert line.get_color() == color, "Legend marker color does not match."

0 comments on commit aeea499

Please sign in to comment.