Skip to content

Commit

Permalink
Merge pull request #100 from LSYS/mplot-dev
Browse files Browse the repository at this point in the history
Fix `mcolor` arg
  • Loading branch information
LSYS committed Jan 14, 2024
2 parents 33bf8d0 + c1d28c8 commit 05c1c05
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 111 deletions.
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Additional options allow easy addition of columns in the `dataframe` as annotati
> - [Quick Start](#quick-start)
> - [Some Examples with Customizations](#some-examples-with-customizations)
> - [Gallery and API Options](#gallery-and-api-options)
> - [Multi-models](#multi-models)
> - [Known Issues](#known-issues)
> - [Background and Additional Resources](#background-and-additional-resources)
> - [Contributing](#contributing)
Expand Down Expand Up @@ -257,6 +258,52 @@ fp.forestplot(df, # the dataframe with results data
</details>
<p align="right">(<a href="#top">back to top</a>)</p>

<!------------------- Multi-models ------------------->
## Multi-models[![](https://raw.githubusercontent.com/LSYS/forestplot/main/docs/images/pin.svg)](#multi-models)

```python
import forestplot as fp

df_mmodel = pd.read_csv("../examples/data/sleep-mmodel.csv").query(
"model=='all' | model=='young kids'"
)
df_mmodel.head(3)
```

| | var | coef | se | T | pval | r2 | adj_r2 | ll | hl | model | group | label |
|---:|:------|-----------:|---------:|----------:|---------:|---------:|-----------:|-----------:|--------:|:-----------|:--------------|:------------|
| 0 | age | 0.994889 | 1.96925 | 0.505213 | 0.613625 | 0.127289 | 0.103656 | -2.87382 | 4.8636 | all | age | in years |
| 3 | age | 22.634 | 15.4953 | 1.4607 | 0.149315 | 0.178147 | -0.0136188 | -8.36124 | 53.6293 | young kids | age | in years |
| 4 | black | -84.7966 | 82.1501 | -1.03222 | 0.302454 | 0.127289 | 0.103656 | -246.186 | 76.5925 | all | other factors | =1 if black |

```python
fp.mforestplot(
dataframe=df_mmodel,
estimate="coef",
ll="ll",
hl="hl",
varlabel="label",
capitalize="capitalize",
model_col="model",
color_alt_rows=True,
groupvar="group",
table=True,
rightannote=["var", "group"],
right_annoteheaders=["Source", "Group"],
xlabel="Coefficient (95% CI)",
modellabels=["Have young kids", "Full sample"],
xticks=[-1200, -600, 0, 600],
# Additional kwargs for customizations
**{
"markersize": 30,
# override default vertical space between models
"offset": 0.4,
},
)
```
<p align="left"><img width="80%" src="https://raw.githubusercontent.com/LSYS/forestplot/mplot-dev/docs/images/multimodel.png"></p>

Please note: This module is still experimental. See [this jupyter notebook](https://nbviewer.org/github/LSYS/forestplot/blob/mplot-dev/examples/test-multmodel-sleep.ipynb) for more examples and tweaks.

<!------------------- GALLERY AND API OPTIONS ------------------->
## Gallery and API Options[![](https://raw.githubusercontent.com/LSYS/forestplot/main/docs/images/pin.svg)](#gallery-and-api-options)
Expand Down
Binary file modified docs/images/group-grouporder-pvalue-sort-colorrows.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/multimodel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
461 changes: 354 additions & 107 deletions examples/readme-examples.ipynb

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions forestplot/mplot_graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Any, List, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -158,7 +159,7 @@ def mdraw_est_markers(
_df = dataframe.query(f'{model_col}=="{modelgroup}"')
base_y_vector = np.arange(len(_df)) - offset / 2 - (offset / 2) * (n - 2)
_y = base_y_vector + (ix * offset)
ax.scatter(y=_y, x=_df[estimate], marker=msymbols[ix], color=mcolor[ix], s=markersize)
ax.scatter(y=_y, x=_df[estimate], marker=msymbols[ix], c=mcolor[ix], s=markersize)
return ax


Expand Down Expand Up @@ -280,9 +281,13 @@ def mdraw_legend(
leg_markersize = kwargs.get("leg_markersize", 8)
leg_artists = []
for ix, symbol in enumerate(msymbols):
leg_artists.append(
Line2D([0], [0], marker=symbol, color=mcolor[ix], markersize=leg_markersize)
)
try:
leg_artists.append(
Line2D([0], [0], marker=symbol, color=mcolor[ix], markersize=leg_markersize)
)
except IndexError:
warnings.warn("'msymbols' and 'mcolor' have different lengths.")
pass
# Handle position of legend
# bbox_to_anchor = kwargs.get("bbox_to_anchor", None)
if len(modellabels) <= 2:
Expand Down

0 comments on commit 05c1c05

Please sign in to comment.