Skip to content

Commit

Permalink
Add docstring & test for mplot (#88, #89) (#96)
Browse files Browse the repository at this point in the history
* Add docstring & test for mplot (#88, #89)

* Pleasing linters
  • Loading branch information
LSYS committed Dec 23, 2023
1 parent b406f7a commit 33bf8d0
Show file tree
Hide file tree
Showing 4 changed files with 2,727 additions and 5 deletions.
2,514 changes: 2,514 additions & 0 deletions examples/test-multmodel-sleep.ipynb

Large diffs are not rendered by default.

59 changes: 58 additions & 1 deletion forestplot/mplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,60 @@ def mforestplot(
table: bool = False,
**kwargs: Any,
) -> Axes:
"""
Generate a forest plot from a DataFrame using Matplotlib.
This function creates a forest plot, which is useful for displaying the estimates from different models
or groups, along with their confidence intervals. It provides a range of customization options for the plot,
including sorting, annotations, and visual style.
Parameters
----------
dataframe : pd.core.frame.DataFrame
The DataFrame containing the data to be plotted.
estimate : str
The name of the column in the DataFrame that contains the estimate values.
varlabel : str
The name of the column used for variable labels on the y-axis.
model_col : str
The name of the column that categorizes data into different models or groups.
models : Optional[Sequence[str]]
The list of models to include in the plot. If None, all models in model_col are used.
modellabels : Optional[Sequence[str]]
Labels for the models, used in the plot legend. If None, model names are used as labels.
ll : Optional[str]
The name of the column representing the lower limit of the confidence intervals.
hl : Optional[str]
The name of the column representing the upper limit of the confidence intervals.
[Other parameters]
...
Returns
-------
Axes or Tuple[pd.core.frame.DataFrame, Axes]
The matplotlib Axes object with the forest plot. If `return_df` is True, a tuple of the
preprocessed DataFrame and the Axes object is returned instead.
Examples
--------
>>> df = pd.DataFrame({
... 'model': ['model1', 'model2'],
... 'estimate': [1.5, 2.0],
... 'll': [1.0, 1.7],
... 'hl': [2.0, 2.3],
... 'varlabel': ['Variable 1', 'Variable 2']
... })
>>> modified_df, ax = mforestplot(df, 'estimate', 'varlabel', 'model', return_df=True)
>>> plt.show()
Notes
-----
- The function is highly customizable with several optional parameters to adjust the appearance and functionality
of the plot.
- If `return_df` is True, the function also returns the DataFrame after preprocessing and sorting based on the
specified parameters.
- The `preprocess` parameter controls whether the input DataFrame should be preprocessed before plotting.
"""
_local_df = dataframe.copy(deep=True)
_local_df = check_data(
dataframe=_local_df,
Expand Down Expand Up @@ -145,7 +199,10 @@ def mforestplot(
table=table,
**kwargs,
)
return _local_df, ax
if return_df:
return _local_df, ax
else:
return ax


def _mpreprocess_dataframe(
Expand Down
106 changes: 102 additions & 4 deletions forestplot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,60 @@ def _preprocess_dataframe(
**kwargs: Any,
) -> pd.core.frame.DataFrame:
"""
Preprocess the dataframe to be ready for plotting.
Preprocess a DataFrame for forest plot visualization, handling various aspects such as sorting,
normalizing labels, and formatting annotations.
Parameters
----------
dataframe : pd.core.frame.DataFrame
The DataFrame to preprocess.
estimate : str
The column name for the estimate values.
varlabel : str
The column name for the variable labels.
ll : Optional[str], default=None
The column name for the lower limit of confidence intervals.
hl : Optional[str], default=None
The column name for the upper limit of confidence intervals.
form_ci_report : Optional[bool], default=False
Flag to determine if confidence interval reporting is required.
ci_report : Optional[bool], default=False
Flag to determine if confidence interval reporting is enabled.
groupvar : Optional[str], default=None
The column name for group variables.
group_order : Optional[Union[list, tuple]], default=None
The order of groups for sorting.
annote : Optional[Union[Sequence[str], None]], default=None
Annotations to add to the DataFrame.
annoteheaders : Optional[Union[Sequence[str], None]], default=None
Headers for the annotations.
rightannote : Optional[Union[Sequence[str], None]], default=None
Right-aligned annotations to add to the DataFrame.
right_annoteheaders : Optional[Union[Sequence[str], None]], default=None
Headers for the right-aligned annotations.
capitalize : Optional[str], default=None
Flag to capitalize certain text elements.
pval : Optional[str], default=None
The column name for p-values.
starpval : bool, default=True
Flag to add stars to significant p-values.
sort : bool, default=False
Flag to enable sorting.
sortby : Optional[str], default=None
The column name to sort by.
sortascend : bool, default=True
Flag to set sorting order.
flush : bool, default=True
Flag to flush certain text elements.
decimal_precision : int, default=2
The number of decimal places for rounding numeric values.
**kwargs : Any
Additional keyword arguments.
Returns
-------
pd.core.frame.DataFrame with additional columns for plotting.
pd.core.frame.DataFrame
The preprocessed DataFrame, ready for visualization with additional columns for plotting.
"""
if (groupvar is not None) and (group_order is not None):
if sort is True:
Expand Down Expand Up @@ -351,11 +400,60 @@ def _make_forestplot(
**kwargs: Any,
) -> Axes:
"""
Draw the forest plot.
Create and draw a forest plot using the given DataFrame and specified parameters.
This function sets up and renders a forest plot using matplotlib, with various options for customization,
including confidence intervals, marker styles, axis labels, and annotations.
Parameters
----------
dataframe : pd.core.frame.DataFrame
The DataFrame containing the data to be plotted.
yticklabel : str
The column name to be used for y-axis tick labels.
estimate : str
The column name representing the central estimate for each observation.
groupvar : str
The column name used for grouping variables in the plot.
pval : str
The column name for the p-value.
xticks : Optional[Union[list, range]]
Custom x-ticks for the plot.
ll : str
The column name for the lower limit of the confidence interval.
hl : str
The column name for the upper limit of the confidence interval.
logscale : bool
Whether to use a logarithmic scale for the x-axis.
flush : bool
Flag to align the y-tick labels.
annoteheaders : Optional[Union[Sequence[str], None]]
Additional annotations to be included in the plot.
rightannote : Optional[Union[Sequence[str], None]]
Annotations to be aligned to the right side of the plot.
right_annoteheaders : Optional[Union[Sequence[str], None]]
Headers for the right-aligned annotations.
ylabel : str
Label for the y-axis.
xlabel : str
Label for the x-axis.
yticker2 : Optional[str]
Additional y-tick labels.
color_alt_rows : bool
Whether to color alternate rows for better readability.
figsize : Union[Tuple, List]
Size of the figure to be created.
despine : bool, default=True
Whether to remove the top and right spines of the plot.
table : bool, default=False
Whether to draw a table-like structure on the plot.
**kwargs : Any
Additional keyword arguments for further customization.
Returns
-------
Matplotlib Axes object.
Axes
The matplotlib Axes object with the forest plot.
"""
_, ax = plt.subplots(figsize=figsize, facecolor="white")
ax = draw_ci(
Expand Down
53 changes: 53 additions & 0 deletions tests/test_mplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pandas as pd
from matplotlib.pyplot import Axes

from forestplot import mforestplot

# Example dataset with estimates from multiple models; it is downloaded from
# the project's GitHub repository when this module is imported.
dataname = "sleep-mmodel"
data = f"https://raw.githubusercontent.com/lsys/forestplot/mplot/examples/data/{dataname}.csv"
df = pd.read_csv(data)


# Baseline keyword arguments shared by every test below.
std_opts = {
    "dataframe": df,
    "estimate": "coef",
    "ll": "ll",
    "hl": "hl",
    "varlabel": "var",
    "model_col": "model",
}


def test_vanilla_mplot():
    """Check the return shape of mforestplot with and without return_df."""
    # Default call: only the Axes comes back.
    axes_only = mforestplot(**std_opts)
    assert isinstance(axes_only, Axes)

    # With return_df=True the result is a (DataFrame, Axes) pair.
    result = mforestplot(**std_opts, return_df=True)
    processed, axes = result
    assert isinstance(axes, Axes)
    assert isinstance(processed, pd.DataFrame)


def test_more_options():
    """Exercise mforestplot with grouping, table layout, annotations and styling kwargs."""
    plot_options = {
        "color_alt_rows": True,
        "groupvar": "group",
        "table": True,
        "rightannote": ["var", "group"],
        "right_annoteheaders": ["Variable", "Variable group"],
        "xlabel": "Coefficient (95% CI)",
        "modellabels": ["Have young kids", "Full sample"],
        "xticks": [-1200, -600, 0, 600],
        "return_df": True,
    }
    # Additional kwargs for customizations
    style_kwargs = {
        "marker": "D",  # set marker symbol as diamond
        "markersize": 35,  # adjust marker size
        "xlinestyle": (0, (10, 5)),  # long dash for x-reference line
        "xlinecolor": "#808080",  # gray color for x-reference line
        "xtick_size": 12,  # adjust x-ticker fontsize
        "despine": False,
    }

    processed, axes = mforestplot(**std_opts, **plot_options, **style_kwargs)

    assert isinstance(axes, Axes)
    assert isinstance(processed, pd.DataFrame)

0 comments on commit 33bf8d0

Please sign in to comment.