From 34d68de4e4e5072f1053e90ad1017248de28b215 Mon Sep 17 00:00:00 2001 From: LucaCappelletti94 Date: Wed, 14 Apr 2021 17:13:56 +0200 Subject: [PATCH] Moved parameters to the barplots method --- barplots/barplot.py | 15 ++-- barplots/barplots.py | 155 ++++++++++++++++++++++++++++++---- tests/test_barplots.py | 5 +- tests/test_simple_barplots.py | 2 + tests/test_single.py | 4 +- 5 files changed, 155 insertions(+), 26 deletions(-) diff --git a/barplots/barplot.py b/barplots/barplot.py index dbf0f9b..99ef92c 100644 --- a/barplots/barplot.py +++ b/barplots/barplot.py @@ -1,3 +1,4 @@ +"""Module implementing plotting of a barplot.""" import pandas as pd from typing import List, Tuple, Dict, Union, Callable from matplotlib.colors import TABLEAU_COLORS, CSS4_COLORS @@ -174,13 +175,10 @@ def barplot( if facecolors is None: facecolors = dict(zip(levels[0], ("white",)*len(levels[0]))) - if sort_subplots is None: - def sort_subplots(x): return x + sorted_level = levels[0] - if sort_bars is None: - def sort_bars(x): return x - - sorted_level = sort_subplots(levels[0]) + if sort_subplots is not None: + sorted_level = sort_subplots(sorted_level) if subplots: titles = sorted_level @@ -212,7 +210,8 @@ def sort_bars(x): return x else: sub_df = df - sub_df = sort_bars(sub_df) + if sort_bars is not None: + sub_df = sort_bars(sub_df) plot_bars(ax, sub_df, bar_width, space_width, alphas, colors, index, vertical=vertical, min_std=min_std) @@ -285,7 +284,7 @@ def sort_bars(x): return x if letter: figure.text( - 0, 1, letter, + 0.01, 0.9, letter, horizontalalignment='center', verticalalignment='center', weight='bold', diff --git a/barplots/barplots.py b/barplots/barplots.py index a2fe3e0..12d9e3e 100644 --- a/barplots/barplots.py +++ b/barplots/barplots.py @@ -1,14 +1,18 @@ +"""Module implementing plotting of multiple barplots in parallel and sequential manner.""" from multiprocessing import Pool, cpu_count -from typing import Dict, List +from typing import Dict, List, Tuple, Callable, Union import pandas as pd from sanitize_ml_labels import sanitize_ml_labels from tqdm.auto import tqdm +from matplotlib.figure import Figure +from matplotlib.axis import Axis from .barplot import barplot -def _barplot(kwargs): +def _barplot(kwargs: Dict) -> Tuple[Figure, Axis]: + """Wrapper over barplot call to expand given kwargs.""" return barplot(**kwargs) @@ -21,11 +25,38 @@ def barplots( path: str = "barplots/{feature}.png", sanitize_metrics: bool = True, letters: Dict[str, str] = None, + bar_width: float = 0.3, + space_width: float = 0.3, + height: float = None, + dpi: int = 200, + min_std: float = 0, + min_value: float = None, + max_value: float = None, + show_legend: bool = True, + show_title: str = True, + legend_position: str = "best", + colors: Dict[str, str] = None, + alphas: Dict[str, float] = None, + facecolors: Dict[str, str] = None, + orientation: str = "vertical", + subplots: bool = False, + plots_per_row: Union[int, str] = "auto", + minor_rotation: float = 0, + major_rotation: float = 0, + unique_minor_labels: bool = False, + unique_major_labels: bool = True, + unique_data_label: bool = True, + auto_normalize_metrics: bool = True, + placeholder: bool = False, + scale: str = "linear", + custom_defaults: Dict[str, List[str]] = None, + sort_subplots: Callable[[List], List] = None, + sort_bars: Callable[[pd.DataFrame], pd.DataFrame] = None, use_multiprocessing: bool = True, verbose: bool = True, - **barplot_kwargs: Dict -): - """ +) -> Tuple[List[Figure], List[Axis]]: + """Returns list of the built figures and axes. + Plot barplots corresponding to given dataframe, grouping by mean and if required also standard deviation. @@ -55,13 +86,79 @@ def barplots( Use the name of the metric (the dataframe column) as key of the dictionary. This is sometimes necessary on papers. By default it is None, that is no letter to be shown. + bar_width: float = 0.3, + Width of the bar of the barplot. + height: float = None, + Height of the barplot. By default golden ratio of the width. + dpi: int = 200, + DPI for plotting the barplots. + min_std: float = 0.001, + Minimum standard deviation for showing error bars. + min_value: float = None, + Minimum value for the barplot. + max_value: float = 0, + Maximum value for the barplot. + show_legend: bool = True, + Whetever to show or not the legend. + If legend is hidden, the bar ticks are shown alternatively. + show_title: str = True, + Whetever to show or not the barplot title. + legend_position: str = "best", + Legend position, by default "best". + data_label: str = None, + Barplot's data_label. + Use None for not showing any data_label (default). + title: str = None, + Barplot's title. + Use None for not showing any title (default). + path: str = None, + Path where to save the barplot. + Use None for not saving it (default). + colors: Dict[str, str] = None, + Dict of colors to be used for innermost index of dataframe. + By default None, using the default color tableau from matplotlib. + alphas: Dict[str, float] = None, + Dict of alphas to be used for innermost index of dataframe. + By default None, using the default alpha. + orientation: str = "vertical", + Orientation of the bars. + Can either be "vertical" of "horizontal". + subplots: bool = False, + Whetever to slit the top indexing layer to multiple subplots. + plots_per_row: Union[int, str] = "auto", + If subplots is True, specifies the number of plots for row. + If "auto" is used, for vertical the default is 2 plots per row, + while for horizontal the default is 4 plots per row. + minor_rotation: float = 0, + Rotation for the minor ticks of the bars. + major_rotation: float = 0, + Rotation for the major ticks of the bars. + unique_minor_labels: bool = False, + Avoid replicating minor labels on the same axis in multiple subplots settings. + unique_major_labels: bool = True, + Avoid replicating major labels on the same axis in multiple subplots settings. + unique_data_label: bool = True, + Avoid replication of data axis label when using subplots. + auto_normalize_metrics: bool = True, + Whetever to apply or not automatic normalization + to the metrics that are recognized to be between + zero and one. For example AUROC, AUPRC or accuracy. + placeholder: bool = False, + Whetever to add a text on top of the barplots to show + the word "placeholder". Useful when generating placeholder data. + scale: str = "linear", + Scale to use for the barplots. + Can either be "linear" or "log". + custom_defaults: Dict[str, List[str]], + Dictionary to normalize labels. use_multiprocessing: bool = True, Whetever to use or not multiple processes. verbose:bool, Whetever to show or not the loading bar. - barplot_kwargs:Dict, - Kwargs parameters to pass to the barplot method. - Read docstring for barplot method for more information on the available parameters. + + Returns + --------------------- + Tuple with list of rendered figures and rendered axes. """ if groupby is not None: groupby = df.groupby(groupby).agg( @@ -80,14 +177,40 @@ def barplots( features = sanitize_ml_labels(features) tasks = [ - { - "df": groupby[original], - "title":title.format(feature=feature.replace("_", " ")), - "data_label":data_label.format(feature=feature.replace("_", " ")), - "path":path.format(feature=feature).replace(" ", "_").lower(), - "letter": letters.get(original, None), - **barplot_kwargs - } for original, feature in zip(original, features) + dict( + df=groupby[original], + title=title.format(feature=feature.replace("_", " ")), + data_label=data_label.format(feature=feature.replace("_", " ")), + path=path.format(feature=feature).replace(" ", "_").lower(), + letter=letters.get(original, None), + bar_width=bar_width, + space_width=space_width, + height=height, + dpi=dpi, + min_std=min_std, + min_value=min_value, + max_value=max_value, + show_legend=show_legend, + show_title=show_title, + legend_position=legend_position, + colors=colors, + alphas=alphas, + facecolors=facecolors, + orientation=orientation, + subplots=subplots, + plots_per_row=plots_per_row, + minor_rotation=minor_rotation, + major_rotation=major_rotation, + unique_minor_labels=unique_minor_labels, + unique_major_labels=unique_major_labels, + unique_data_label=unique_data_label, + auto_normalize_metrics=auto_normalize_metrics, + placeholder=placeholder, + scale=scale, + custom_defaults=custom_defaults, + sort_subplots=sort_subplots, + sort_bars=sort_bars, + ) for original, feature in zip(original, features) if not pd.isna(groupby[original]).any().any() and not len(groupby[original]) == 0 ] diff --git a/tests/test_barplots.py b/tests/test_barplots.py index 4b0185a..7b68930 100644 --- a/tests/test_barplots.py +++ b/tests/test_barplots.py @@ -3,6 +3,7 @@ import os import itertools from tqdm.auto import tqdm +import matplotlib.pyplot as plt def test_barplots(): @@ -118,7 +119,8 @@ def test_barplots(): barplots( **kwargs, path=path, - custom_defaults=custom_defaults, verbose=False) + custom_defaults=custom_defaults, verbose=False) + plt.close() def test_single_index(): @@ -130,3 +132,4 @@ def test_single_index(): path="{root}/{{feature}}.png".format(root=root), verbose=False ) + plt.close() diff --git a/tests/test_simple_barplots.py b/tests/test_simple_barplots.py index 4310a8c..9581963 100644 --- a/tests/test_simple_barplots.py +++ b/tests/test_simple_barplots.py @@ -1,5 +1,6 @@ import pandas as pd from barplots import barplots +import matplotlib.pyplot as plt def test_simple_barplots(): @@ -26,3 +27,4 @@ def test_simple_barplots(): subplots=True, space_width=1, ) + plt.close() diff --git a/tests/test_single.py b/tests/test_single.py index 30bc285..76bc515 100644 --- a/tests/test_single.py +++ b/tests/test_single.py @@ -1,5 +1,6 @@ import pandas as pd from barplots import barplots +import matplotlib.pyplot as plt def test_single_index(): @@ -10,4 +11,5 @@ def test_single_index(): "cell_line", path="{root}/{{feature}}.png".format(root=root), verbose=False - ) \ No newline at end of file + ) + plt.close()