Skip to content

Commit

Permalink
feat(plot-counter): add function to plot counter as bar plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Axel Fahy committed Sep 5, 2019
1 parent 253a9ef commit 7b024a1
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 0 deletions.
2 changes: 2 additions & 0 deletions bff/plot/__init__.py
@@ -1,6 +1,7 @@
"""Plot module of bff."""

from .plot import (
plot_counter,
plot_history,
plot_predictions,
plot_series,
Expand All @@ -9,6 +10,7 @@

# Public object of the module.
__all__ = [
'plot_counter',
'plot_history',
'plot_predictions',
'plot_series',
Expand Down
136 changes: 136 additions & 0 deletions bff/plot/plot.py
Expand Up @@ -4,6 +4,7 @@
This module contains fancy plot functions.
"""
import logging
from collections import Counter
from typing import Sequence, Tuple, Union
import matplotlib as mpl
import matplotlib.lines as mlines
Expand All @@ -22,6 +23,141 @@
register_matplotlib_converters()


def plot_counter(counter: Union[Counter, dict],
label_x: str = 'x', label_y: str = 'y',
title: str = 'Bar chart', width: float = 0.9,
threshold: int = 0, vertical: bool = True,
ax: plt.axes = None,
rotation_xticks: Union[float, None] = None,
grid: Union[str, None] = 'y',
figsize: Tuple[int, int] = (14, 5), dpi: int = 80,
style: str = 'default', **kwargs) -> plt.axes:
"""
Plot the values of a counter as an bar plot.
Values above the ratio are written as text on top of the bar.
Parameters
----------
counter : collections.Counter or dictionary
Counter or dictionary to plot.
label_x : str, default 'x'
Label for x axis.
label_y : str, default 'y'
Label for y axis.
title : str, default 'Bar chart'
Title for the plot (axis level).
width : float, default 0.9
Width of the bar. If below 1.0, there will be space between them.
threshold : int, default = 0
Threshold above which the value is written on the plot as text.
By default, all bar have their text.
vertical : bool, default True
By default, vertical bar. If set to False, will plot using `plt.barh`
and inverse the labels.
ax : plt.axes, default None
Axes from matplotlib, if None, new figure and axes will be created.
loc : str or int, default 'best'
Location of the legend on the plot.
Either the legend string or legend code are possible.
rotation_xticks : float or None, default None
Rotation of x ticks if any.
Set to 90 to put them vertically.
grid : str or None, default 'y'
Axis where to activate the grid ('both', 'x', 'y').
To turn off, set to None.
figsize : Tuple[int, int], default (14, 5)
Size of the figure to plot.
dpi : int, default 80
Resolution of the figure.
style : str, default 'default'
Style to use for matplotlib.pyplot.
The style is use only in this context and not applied globally.
**kwargs
Additional keyword arguments to be passed to the
`plt.plot` function from matplotlib.
Returns
-------
plt.axes
Axes returned by the `plt.subplots` function.
Examples
--------
>>> from collections import Counter
>>> counter = Counter({'red': 4, 'blue': 2})
>>> plot_counter(counter, title='MyTitle', rotation_xticks=90)
"""
with plt.style.context(style):
if ax is None:
__, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

labels, values = zip(*sorted(counter.items()))

indexes = np.arange(len(labels))

if vertical:
ax.bar(indexes, values, width, **kwargs)
else:
ax.barh(indexes, values, width, **kwargs)
label_x, label_y = label_y, label_x

ax.set_xlabel(label_x, fontsize=12)
ax.set_ylabel(label_y, fontsize=12)
ax.set_title(title, fontsize=14)

# Write the real value on the bar if above a given threshold.
counter_max = max(counter.values())
space = counter_max * 1.01 - counter_max
for i in ax.patches:
if vertical:
if i.get_height() > threshold:
ax.text(i.get_x() + i.get_width() / 2, i.get_height() + space,
f'{i.get_height():^{len(str(counter_max))},}',
ha='center', va='bottom',
fontsize=10, color='black', alpha=0.6)
else:
if i.get_width() > threshold:
ax.text(i.get_width() + space, i.get_y() + i.get_height() / 2,
f'{i.get_width():>{len(str(counter_max))},}',
ha='left', va='center',
fontsize=10, color='black', alpha=0.6)

# Style.
# Remove border on the top and right.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Set alpha on remaining borders.
ax.spines['left'].set_alpha(0.4)
ax.spines['bottom'].set_alpha(0.4)

# Remove ticks on y axis.
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')

# Draw tick lines on wanted axes.
if grid:
ax.axes.grid(True, which='major', axis=grid, color='black',
alpha=0.3, linestyle='--', lw=0.5)

# Style of ticks.
# Set a thousand separator axis.
ax.xaxis.set_major_formatter(
mpl.ticker.FuncFormatter(lambda x, p: f'{x:,.1f}')
)
ax.yaxis.set_major_formatter(
mpl.ticker.FuncFormatter(lambda x, p: f'{x:,.1f}')
)
if vertical:
plt.xticks(indexes, labels, fontsize=10, alpha=0.7, rotation=rotation_xticks)
plt.yticks(fontsize=10, alpha=0.7)
else:
plt.xticks(fontsize=10, alpha=0.7, rotation=rotation_xticks)
plt.yticks(indexes, labels, fontsize=10, alpha=0.7)

return ax


def plot_history(history: dict, metric: Union[str, None] = None,
title: str = 'Model history', axes: plt.axes = None,
loc: Union[str, int] = 'best', grid: Union[str, None] = None,
Expand Down
1 change: 1 addition & 0 deletions doc/source/fancy.rst
Expand Up @@ -13,6 +13,7 @@ All of bff's functions.
bff.idict
bff.mem_usage_pd
bff.parse_date
bff.plot.plot_counter
bff.plot.plot_history
bff.plot.plot_predictions
bff.plot.plot_series
Expand Down
Binary file added tests/baseline/test_plot_counter.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_plot_counter_dict.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_plot_counter_horizontal.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 37 additions & 0 deletions tests/test_plot.py
Expand Up @@ -5,6 +5,7 @@
Assertion and resulting images are tested.
"""
from collections import Counter
import unittest
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -68,6 +69,42 @@ class TestPlot(unittest.TestCase):
.drop(pd.date_range('2018-01-01 00:40', '2018-01-01 00:41', freq='S'))
.drop(pd.date_range('2018-01-01 00:57', '2018-01-01 00:59', freq='S'))
)
# Counter for bar chart.
counter = Counter({'xelqo': 3, 'nisqo': 397, 'bff': 7454, 'eszo': 300, 'hedo': 26,
'sevcyk': 13, 'ajet': 31, 'zero': 10, 'exudes': 4, 'frazzio': 2})
# Dictionary for bar chart.
dict_to_plot = {'Red': 15, 'Green': 50, 'Blue': 24}

@pytest.mark.mpl_image_compare
def test_plot_counter(self):
"""
Test of the `plot_counter` function.
"""
ax = bplt.plot_counter(self.counter)
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_counter_horizontal(self):
"""
Test of the `plot_counter` function.
Check the behaviour with `vertical=False`.
"""
ax = bplt.plot_counter(self.counter, vertical=False, threshold=300,
title='Bar chart of fake companies turnover [Bn.]',
label_x='Company', label_y='Turnover',
grid='x', rotation_xticks=45, figsize=(10, 7))
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_counter_dict(self):
"""
Test of the `plot_counter` function.
Check the behaviour when using a dictionary.
"""
ax = bplt.plot_counter(self.dict_to_plot, grid=None)
return ax.figure

def test_plot_history(self):
"""
Expand Down

0 comments on commit 7b024a1

Please sign in to comment.