Skip to content

Commit

Permalink
feat(plot-correlation): add function to plot correlation
Browse files Browse the repository at this point in the history
Closes #21.
  • Loading branch information
Axel Fahy committed Sep 27, 2019
1 parent 18ec6e3 commit d54bfaf
Show file tree
Hide file tree
Showing 12 changed files with 112 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -3,7 +3,7 @@
.PHONY: all
all: test

.PHONY: all
.PHONY: build-python
build-python:
python setup.py sdist bdist_wheel

Expand Down
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -52,6 +52,7 @@ As of *v0.2*, plots are not yet tested in the travis build.

* 0.2.3
* ADD: Function ``normalization_pd`` to normalize a DataFrame.
* ADD: Function ``plot_correlation`` to plot the correlation of variables in a DataFrame.
* 0.2.2
* FIX: Function ``value_2_list`` renamed to ``kwargs_2_list``.
* ADD: Function ``value_2_list`` to cast a single value.
Expand Down
2 changes: 2 additions & 0 deletions bff/plot/__init__.py
@@ -1,6 +1,7 @@
"""Plot module of bff."""

from .plot import (
plot_correlation,
plot_counter,
plot_history,
plot_predictions,
Expand All @@ -10,6 +11,7 @@

# Public object of the module.
__all__ = [
'plot_correlation',
'plot_counter',
'plot_history',
'plot_predictions',
Expand Down
82 changes: 82 additions & 0 deletions bff/plot/plot.py
Expand Up @@ -13,6 +13,7 @@
import numpy as np
import pandas as pd
from pandas.plotting import register_matplotlib_converters
import seaborn as sns

import bff.fancy

Expand All @@ -23,6 +24,87 @@
register_matplotlib_converters()


def plot_correlation(df: pd.DataFrame, already_computed: bool = False,
method: str = 'pearson', title: str = 'Correlation between variables',
ax: plt.axes = None, rotation_xticks: Union[float, None] = 90,
rotation_yticks: Union[float, None] = None,
figsize: Tuple[int, int] = (13, 10), dpi: int = 80,
style: str = 'white',
**kwargs):
"""
Plot the correlation between variables of a pandas DataFrame.
The computing of the correlation can be done either in the
function or before.
Parameters
----------
df : pd.DataFrame
DataFrame with the values or the correlations.
already_computed : bool, default False
Set to True if the DataFrame already contains the correlations.
method : str, default 'pearson'
Type of normalization. See pandas.DataFrame.corr for possible values.
title : str, default 'Correlation between variables'
Title for the plot (axis level).
ax : plt.axes, default None
Axes from matplotlib, if None, new figure and axes will be created.
rotation_xticks : float or None, default 90
Rotation of x ticks if any.
rotation_yticks : float or None, default None
Rotation of x ticks if any.
Set to 90 to put them vertically.
figsize : Tuple[int, int], default (13, 10)
Size of the figure to plot.
dpi : int, default 80
Resolution of the figure.
style : str, default 'white'
Style to use for seaborn.axes_style.
The style is use only in this context and not applied globally.
**kwargs
Additional keyword arguments to be passed to the
`sns.heatmap` function from seaborn.
Returns
-------
plt.axes
Axes returned by the `plt.subplots` function.
"""
# Compute the correlation if needed.
if not already_computed:
df = df.corr(method=method)

# Generate a mask for the upper triangle.
mask = np.zeros_like(df, dtype=np.bool)
mask[np.triu_indices_from(mask)] = True # pylint: disable=unsupported-assignment-operation

# Generate a custom diverging colormap.
cmap = sns.diverging_palette(220, 10, as_cmap=True)

with sns.axes_style(style):
if ax is None:
__, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

# Draw the heatmap with the mask and correct aspect ratio.
sns.heatmap(df, mask=mask, cmap=cmap, ax=ax, vmin=-1, vmax=1, center=0,
annot=True, square=True, linewidths=0.5,
cbar_kws={"shrink": 0.75}, **kwargs)

ax.set_title(title, fontsize=14)
# Style.
# Remove border on the top and right.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Only show ticks on the left and bottom spines
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
# Style of ticks.
plt.xticks(fontsize=10, alpha=0.7, rotation=rotation_xticks)
plt.yticks(fontsize=10, alpha=0.7, rotation=rotation_yticks)

return ax


def plot_counter(counter: Union[Counter, dict],
label_x: str = 'x', label_y: str = 'y',
title: str = 'Bar chart', width: float = 0.9,
Expand Down
1 change: 1 addition & 0 deletions doc/source/fancy.rst
Expand Up @@ -15,6 +15,7 @@ All of bff's functions.
bff.mem_usage_pd
bff.normalization_pd
bff.parse_date
bff.plot.plot_correlation
bff.plot.plot_counter
bff.plot.plot_history
bff.plot.plot_predictions
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Expand Up @@ -23,13 +23,14 @@
'Source Code': 'https://github.com/axelfahy/bff'
}
REQUIRES = [
'matplotlib==3.1.1',
'matplotlib==3.1.0',
'numpy==1.17.0',
'pandas==0.25.0',
'python-dateutil==2.8.0',
'pyyaml==5.1.2',
'scikit-learn==0.21.3',
'scipy==1.3.1',
'seaborn==0.9.0',
'typing==3.7.4'
]
CLASSIFIERS = [
Expand Down
Binary file added tests/baseline/test_plot_correlation.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_plot_correlation_with_ax.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_plot_counter.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_plot_counter_dict.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_plot_counter_horizontal.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 23 additions & 0 deletions tests/test_plot.py
Expand Up @@ -75,6 +75,29 @@ class TestPlot(unittest.TestCase):
# Dictionary for bar chart.
dict_to_plot = {'Red': 15, 'Green': 50, 'Blue': 24}

@pytest.mark.mpl_image_compare
def test_plot_correlation(self):
"""
Test of the `plot_correlation` function.
"""
ax = bplt.plot_correlation(self.data)
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_correlation_with_ax(self):
"""
Test of the `plot_correlation` function.
"""
# Create fake data for one of the plot.
df_tmp = pd.DataFrame({'x': [123, 27, 38, 45, 67], 'y': [456, 45.4, 32, 34, 90]})
df_corr = df_tmp.corr()

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 10), dpi=80)
bplt.plot_correlation(df_corr, already_computed=True, ax=axes[0],
rotation_xticks=0, title='Correlation between x and y')
bplt.plot_correlation(self.data, ax=axes[1], method='spearman')
return fig

@pytest.mark.mpl_image_compare
def test_plot_counter(self):
"""
Expand Down

0 comments on commit d54bfaf

Please sign in to comment.