Skip to content

Commit

Permalink
Moved parameters to the barplots method
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaCappelletti94 committed Apr 14, 2021
1 parent b1722e6 commit 34d68de
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 26 deletions.
15 changes: 7 additions & 8 deletions barplots/barplot.py
@@ -1,3 +1,4 @@
"""Module implementing plotting of a barplot."""
import pandas as pd
from typing import List, Tuple, Dict, Union, Callable
from matplotlib.colors import TABLEAU_COLORS, CSS4_COLORS
Expand Down Expand Up @@ -174,13 +175,10 @@ def barplot(
if facecolors is None:
facecolors = dict(zip(levels[0], ("white",)*len(levels[0])))

if sort_subplots is None:
def sort_subplots(x): return x
sorted_level = levels[0]

if sort_bars is None:
def sort_bars(x): return x

sorted_level = sort_subplots(levels[0])
if sort_subplots is not None:
sorted_level = sort_subplots(sorted_level)

if subplots:
titles = sorted_level
Expand Down Expand Up @@ -212,7 +210,8 @@ def sort_bars(x): return x
else:
sub_df = df

sub_df = sort_bars(sub_df)
if sort_bars is not None:
sub_df = sort_bars(sub_df)

plot_bars(ax, sub_df, bar_width, space_width, alphas, colors, index,
vertical=vertical, min_std=min_std)
Expand Down Expand Up @@ -285,7 +284,7 @@ def sort_bars(x): return x

if letter:
figure.text(
0, 1, letter,
0.01, 0.9, letter,
horizontalalignment='center',
verticalalignment='center',
weight='bold',
Expand Down
155 changes: 139 additions & 16 deletions barplots/barplots.py
@@ -1,14 +1,18 @@
"""Module implementing plotting of multiple barplots in parallel and sequential manner."""
from multiprocessing import Pool, cpu_count
from typing import Dict, List
from typing import Dict, List, Tuple, Callable, Union

import pandas as pd
from sanitize_ml_labels import sanitize_ml_labels
from tqdm.auto import tqdm
from matplotlib.figure import Figure
from matplotlib.axis import Axis

from .barplot import barplot


def _barplot(kwargs):
def _barplot(kwargs: Dict) -> Tuple[Figure, Axis]:
"""Wrapper over barplot call to expand given kwargs."""
return barplot(**kwargs)


Expand All @@ -21,11 +25,38 @@ def barplots(
path: str = "barplots/{feature}.png",
sanitize_metrics: bool = True,
letters: Dict[str, str] = None,
bar_width: float = 0.3,
space_width: float = 0.3,
height: float = None,
dpi: int = 200,
min_std: float = 0,
min_value: float = None,
max_value: float = None,
show_legend: bool = True,
show_title: str = True,
legend_position: str = "best",
colors: Dict[str, str] = None,
alphas: Dict[str, float] = None,
facecolors: Dict[str, str] = None,
orientation: str = "vertical",
subplots: bool = False,
plots_per_row: Union[int, str] = "auto",
minor_rotation: float = 0,
major_rotation: float = 0,
unique_minor_labels: bool = False,
unique_major_labels: bool = True,
unique_data_label: bool = True,
auto_normalize_metrics: bool = True,
placeholder: bool = False,
scale: str = "linear",
custom_defaults: Dict[str, List[str]] = None,
sort_subplots: Callable[[List], List] = None,
sort_bars: Callable[[pd.DataFrame], pd.DataFrame] = None,
use_multiprocessing: bool = True,
verbose: bool = True,
**barplot_kwargs: Dict
):
"""
) -> Tuple[List[Figure], List[Axis]]:
"""Returns list of the built figures and axes.
Plot barplots corresponding to given dataframe,
grouping by mean and if required also standard deviation.
Expand Down Expand Up @@ -55,13 +86,79 @@ def barplots(
Use the name of the metric (the dataframe column) as key of the dictionary.
This is sometimes necessary on papers.
By default it is None, that is no letter to be shown.
bar_width: float = 0.3,
Width of the bar of the barplot.
height: float = None,
Height of the barplot. By default golden ratio of the width.
dpi: int = 200,
DPI for plotting the barplots.
min_std: float = 0,
Minimum standard deviation for showing error bars.
min_value: float = None,
Minimum value for the barplot.
max_value: float = None,
Maximum value for the barplot.
show_legend: bool = True,
Whether or not to show the legend.
If legend is hidden, the bar ticks are shown alternatively.
show_title: str = True,
Whether or not to show the barplot title.
legend_position: str = "best",
Legend position, by default "best".
data_label: str = None,
Barplot's data_label.
Use None for not showing any data_label (default).
title: str = None,
Barplot's title.
Use None for not showing any title (default).
path: str = "barplots/{feature}.png",
Path where to save the barplot.
The "{feature}" placeholder is replaced with the name of the plotted feature.
colors: Dict[str, str] = None,
Dict of colors to be used for innermost index of dataframe.
By default None, using the default color tableau from matplotlib.
alphas: Dict[str, float] = None,
Dict of alphas to be used for innermost index of dataframe.
By default None, using the default alpha.
orientation: str = "vertical",
Orientation of the bars.
Can either be "vertical" or "horizontal".
subplots: bool = False,
Whether to split the top indexing layer into multiple subplots.
plots_per_row: Union[int, str] = "auto",
If subplots is True, specifies the number of plots per row.
If "auto" is used, for vertical the default is 2 plots per row,
while for horizontal the default is 4 plots per row.
minor_rotation: float = 0,
Rotation for the minor ticks of the bars.
major_rotation: float = 0,
Rotation for the major ticks of the bars.
unique_minor_labels: bool = False,
Avoid replicating minor labels on the same axis in multiple subplots settings.
unique_major_labels: bool = True,
Avoid replicating major labels on the same axis in multiple subplots settings.
unique_data_label: bool = True,
Avoid replication of data axis label when using subplots.
auto_normalize_metrics: bool = True,
Whether or not to apply automatic normalization
to the metrics that are recognized to be between
zero and one. For example AUROC, AUPRC or accuracy.
placeholder: bool = False,
Whether to add a text on top of the barplots to show
the word "placeholder". Useful when generating placeholder data.
scale: str = "linear",
Scale to use for the barplots.
Can either be "linear" or "log".
custom_defaults: Dict[str, List[str]] = None,
Dictionary to normalize labels.
use_multiprocessing: bool = True,
Whether or not to use multiple processes.
verbose: bool = True,
Whether or not to show the loading bar.
barplot_kwargs: Dict,
Kwargs parameters to pass to the barplot method.
Read docstring for barplot method for more information on the available parameters.
Returns
---------------------
Tuple with list of rendered figures and rendered axes.
"""
if groupby is not None:
groupby = df.groupby(groupby).agg(
Expand All @@ -80,14 +177,40 @@ def barplots(
features = sanitize_ml_labels(features)

tasks = [
{
"df": groupby[original],
"title":title.format(feature=feature.replace("_", " ")),
"data_label":data_label.format(feature=feature.replace("_", " ")),
"path":path.format(feature=feature).replace(" ", "_").lower(),
"letter": letters.get(original, None),
**barplot_kwargs
} for original, feature in zip(original, features)
dict(
df=groupby[original],
title=title.format(feature=feature.replace("_", " ")),
data_label=data_label.format(feature=feature.replace("_", " ")),
path=path.format(feature=feature).replace(" ", "_").lower(),
letter=letters.get(original, None),
bar_width=bar_width,
space_width=space_width,
height=height,
dpi=dpi,
min_std=min_std,
min_value=min_value,
max_value=max_value,
show_legend=show_legend,
show_title=show_title,
legend_position=legend_position,
colors=colors,
alphas=alphas,
facecolors=facecolors,
orientation=orientation,
subplots=subplots,
plots_per_row=plots_per_row,
minor_rotation=minor_rotation,
major_rotation=major_rotation,
unique_minor_labels=unique_minor_labels,
unique_major_labels=unique_major_labels,
unique_data_label=unique_data_label,
auto_normalize_metrics=auto_normalize_metrics,
placeholder=placeholder,
scale=scale,
custom_defaults=custom_defaults,
sort_subplots=sort_subplots,
sort_bars=sort_bars,
) for original, feature in zip(original, features)
if not pd.isna(groupby[original]).any().any() and not len(groupby[original]) == 0
]

Expand Down
5 changes: 4 additions & 1 deletion tests/test_barplots.py
Expand Up @@ -3,6 +3,7 @@
import os
import itertools
from tqdm.auto import tqdm
import matplotlib.pyplot as plt


def test_barplots():
Expand Down Expand Up @@ -118,7 +119,8 @@ def test_barplots():

barplots(
**kwargs, path=path,
custom_defaults=custom_defaults, verbose=False)
custom_defaults=custom_defaults, verbose=False)
plt.close()


def test_single_index():
Expand All @@ -130,3 +132,4 @@ def test_single_index():
path="{root}/{{feature}}.png".format(root=root),
verbose=False
)
plt.close()
2 changes: 2 additions & 0 deletions tests/test_simple_barplots.py
@@ -1,5 +1,6 @@
import pandas as pd
from barplots import barplots
import matplotlib.pyplot as plt


def test_simple_barplots():
Expand All @@ -26,3 +27,4 @@ def test_simple_barplots():
subplots=True,
space_width=1,
)
plt.close()
4 changes: 3 additions & 1 deletion tests/test_single.py
@@ -1,5 +1,6 @@
import pandas as pd
from barplots import barplots
import matplotlib.pyplot as plt


def test_single_index():
Expand All @@ -10,4 +11,5 @@ def test_single_index():
"cell_line",
path="{root}/{{feature}}.png".format(root=root),
verbose=False
)
)
plt.close()

0 comments on commit 34d68de

Please sign in to comment.