# StatsPlot
This notebook demonstrates the main features of the StatsPlot API.

Full details of the API can be found in the [documentation](https://parici75.github.io/statsplot).


In [1]:
import pandas as pd
import numpy as np

import plotly.express as px
from plotly.subplots import make_subplots

import plotly.io as pio
pio.renderers.default = 'notebook'

import statsplot

## Heatmap plots
A Cartesian plot with data visualized as colored rectangular tiles.

Note that the function does not operate on wide-form DataFrames, so the data must first be melted into long form.
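Melting turns the columns of a wide-form table into rows of a long-form one. The example below uses a small toy table shaped like `medals_wide(indexed=True)`; the numbers are made up. `DataFrame.melt(ignore_index=False)` keeps the `nation` index and stacks the medal columns:

```python
import pandas as pd

# Toy wide-form table: one row per nation, one column per medal type
wide = pd.DataFrame(
    {"gold": [3, 1], "silver": [2, 4], "bronze": [5, 0]},
    index=pd.Index(["Nation A", "Nation B"], name="nation"),
)
wide.columns.name = "medal"

# ignore_index=False keeps the "nation" index; the column labels are stacked into
# a "medal" column (named after columns.name) next to a "value" column
long = wide.melt(ignore_index=False)
print(long)
```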

In [2]:
df = px.data.medals_wide(indexed=True)

fig = statsplot.heatmap(data=df.melt(ignore_index=False),
                        x="nation",
                        y="medal",
                        z="value",
                        opacity=0.8,
                        color_palette=["#d4f542", "#4275f5"],
                        axis="square"
                       )
fig.show()

Compared to `plotly.express.imshow`, `statsplot.heatmap` can slice the data along a particular dimension. This is handy for visually inspecting parts of correlation matrices:

In [3]:
df = px.data.iris().set_index('species')
features = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
# One correlation matrix per species, stacked under a "species" index level
corr_df = pd.concat([df.loc[species, features].corr() for species in df.index.unique()],
                    keys=df.index.unique(), names=["species", "index"])
# Prefix each row label with its species so that row labels stay unique across slices
corr_df.index = pd.MultiIndex.from_arrays((corr_df.index.get_level_values(0),
                                           pd.Index(["_".join(idx) for idx in corr_df.index], name="index")))

fig = statsplot.heatmap(
    data=corr_df.melt(ignore_index=False,
                      value_name="correlation"),
    x="index",
    y="variable",
    z="correlation",
    color_palette="RdBu_r",
    colorbar=True,
    opacity=1,
    slicer="species",
    axis="equal"
)
fig.show()
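The per-species matrices can also be obtained with `DataFrame.groupby(...).corr()`. This returns them stacked under a `(species, feature)` MultiIndex, which is the same layout the `pd.concat` call above builds before its rows are relabeled. Below is a minimal sketch on toy data that stands in for the iris measurements:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Toy stand-in for the iris measurements: two species, three features
df = pd.DataFrame(rng.normal(size=(40, 3)), columns=["a", "b", "c"])
df["species"] = ["x"] * 20 + ["y"] * 20

# One correlation matrix per species, stacked under a (species, feature) MultiIndex
corr_df = df.groupby("species").corr()

# Selecting the outer level returns the matrix for a single slice
x_corr = corr_df.loc["x"]
print(x_corr)
```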

The base of a logarithmic scale for the colormap can also be set with the `logscale` parameter:

In [4]:
df = px.data.stocks().set_index('date')
df.index = pd.DatetimeIndex(df.index, yearfirst=True, name="date")

fig = statsplot.heatmap(
    data=df.melt(ignore_index=False,
                 var_name="company",
                 value_name='stock_value'),
    x="company",
    y="date",
    z='stock_value',
    logscale=10,
)
fig.show()
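For an intuition of what `logscale=10` does, here is a generic sketch of a logarithmic color mapping. Values pass through a base-10 logarithm before being normalized to the [0, 1] range of the colormap, so each factor of 10 covers the same share of colors. The helper `log_color_positions` is hypothetical and not part of statsplot, which may compute its scale differently:

```python
import numpy as np

def log_color_positions(values, base=10):
    """Map strictly positive values to [0, 1] on a logarithmic scale of the given base.

    Hypothetical helper illustrating a log colormap; not statsplot's implementation.
    """
    values = np.asarray(values, dtype=float)
    logs = np.log(values) / np.log(base)  # change of base: log_b(v) = ln(v) / ln(b)
    # Min-max normalization in log space: smallest value -> 0, largest -> 1
    return (logs - logs.min()) / (logs.max() - logs.min())

print(log_color_positions([1, 10, 100, 1000]))
```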

## Categorical plots
A plot to visualize data ordered along categories.

`statsplot` offers three varieties of visualization: stripplot, boxplot, and violinplot.

### stripplot

In [5]:
df = px.data.tips()

fig = statsplot.catplot(
    data=df,
    y="total_bill",
    x="sex",
    plot_type="stripplot",
    size="size",
    slicer='sex',
)
fig.show()

### violinplot

In [6]:
df = px.data.tips()

fig = statsplot.catplot(
    data=df,
    y="total_bill",
    x="sex",
    plot_type="violinplot",
    size=8,
    slicer='smoker',
)
fig.show()

### boxplot

Slices can be reordered and filtered by providing the `slice_order` argument; slices that are not listed are left out.

In [7]:
df = px.data.tips()

fig = statsplot.catplot(data=df,
                        y="total_bill",
                        x="sex",
                        plot_type="boxplot",
                        slicer='day',
                        slice_order=["Fri", "Sat", "Sun"],
                       )
fig.show()
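You can mimic the effect of `slice_order` on the data in pandas. Keep only the listed categories, then reindex the per-slice results in the requested order. The sketch below does this on a toy stand-in for the tips data. It also computes the quartiles that a box plot summarizes. It illustrates the idea and is not statsplot's implementation:

```python
import pandas as pd

# Toy stand-in for px.data.tips()
df = pd.DataFrame({
    "day": ["Thur", "Sun", "Fri", "Sat", "Sun", "Thur"],
    "total_bill": [10.0, 20.0, 15.0, 30.0, 25.0, 12.0],
})
slice_order = ["Fri", "Sat", "Sun"]

# Keep only the listed slices ("Thur" is dropped)
subset = df[df["day"].isin(slice_order)]

# Quartiles per slice (lower edge, middle line and upper edge of each box),
# with rows reordered as requested by slice_order
quartiles = (subset.groupby("day")["total_bill"]
             .quantile([0.25, 0.5, 0.75])
             .unstack()
             .reindex(slice_order))
print(quartiles)
```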

## Scatter plots
The most generic and flexible plot to visualize raw data.

Data for error bars can be supplied as a two-dimensional array (here, a column of `(low, high)` tuples) specifying the left/right or bottom/top extent of the error or measure of variation to depict.

In [8]:
df = px.data.iris()
df["error_x"] = list(zip(df.sepal_width * 0.98, df.sepal_width * 1.01))

fig = statsplot.plot(data=df,
                     mode='markers',
                     marker="diamond",
                     x="sepal_width",
                     y="sepal_length",
                     slicer="species",
                     error_x="error_x",
                     axis='id_line',
                    )
fig.show()
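The `0.98`/`1.01` factors suggest the tuples hold absolute lower and upper bounds. Assuming so, they can be converted into distances from each point, which is the form plotly's asymmetric error bars take (`error_x=dict(array=..., arrayminus=...)`). Whether statsplot performs exactly this conversion is an assumption; the sketch only shows the arithmetic:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"sepal_width": [3.5, 3.0, 3.2]})
# (lower, upper) bounds stored as tuples, as in the cell above
df["error_x"] = list(zip(df.sepal_width * 0.98, df.sepal_width * 1.01))

bounds = np.array(df["error_x"].tolist())             # shape (n, 2)
x = df["sepal_width"].to_numpy()
minus = x - bounds[:, 0]  # distance from each point to its left bound
plus = bounds[:, 1] - x   # distance from each point to its right bound
print(minus, plus)
```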

Continuous error shades can be drawn on line plots, along with regression lines:

In [9]:
df = px.data.gapminder()
df["shaded_error"] = list(zip(0.8 * df.gdpPercap,
                              1.5 * df.gdpPercap))  # Fake error shade

fig = statsplot.plot(
    data=df.sort_values("lifeExp"),
    y="gdpPercap",
    x="lifeExp",
    shaded_error="shaded_error",
    color_palette="tab20",
    slicer='country',
    fit='linear',
)
fig.show()
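`fit='linear'` draws a regression line for each slice. An ordinary least-squares line is one common way to obtain such a fit; statsplot's exact fitting routine is not shown here. It can be computed with `numpy.polyfit`:

```python
import numpy as np

# Toy data lying close to y = 2x + 1
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([1.1, 2.9, 5.0, 7.1, 8.9])

# Least-squares fit of a degree-1 polynomial; coefficients come highest degree first
slope, intercept = np.polyfit(x, y, deg=1)
y_fit = slope * x + intercept
print(slope, intercept)
```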

Supplying a `z` dimension generates a 3D scatter plot:

In [10]:
df = px.data.iris()
df["error_z"] = list(zip(df.petal_width * 0.96, df.petal_width * 1.05))

fig = statsplot.plot(
    data=df,
    x="sepal_length",
    y="sepal_width",
    z="petal_width",
    error_z="error_z",
    slicer='species',
    mode="markers",
)
fig.show()

Color coding can also be achieved with a discrete colormap: simply supply the numerical color dimension as strings.

In [11]:
df = px.data.iris()
df["species_id"] = df["species_id"].astype(str)

fig = statsplot.plot(
    data=df,
    x="sepal_length",
    y="sepal_width",
    z="petal_width",
    color="species_id",
    mode="markers",
)
fig.show()
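The cast matters because a numeric column is usually mapped onto a continuous colorscale, while strings are treated as categories that each receive their own color. Below is a minimal sketch of such a discrete mapping. The palette and the `color_map` logic are illustrative assumptions, not statsplot's internals:

```python
import pandas as pd

df = pd.DataFrame({"species_id": [1, 2, 3, 1]})
print(pd.api.types.is_numeric_dtype(df["species_id"]))  # numeric: continuous scale

df["species_id"] = df["species_id"].astype(str)
print(pd.api.types.is_numeric_dtype(df["species_id"]))  # strings: discrete categories

# Hypothetical palette: one color per unique category, assigned in sorted order
palette = ["#1f77b4", "#ff7f0e", "#2ca02c"]
color_map = dict(zip(sorted(df["species_id"].unique()), palette))
colors = df["species_id"].map(color_map)
```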

## Bar plots
Each row of the DataFrame is represented as a rectangular mark.

In [12]:
df = px.data.stocks()

fig = statsplot.barplot(
    data=df.set_index("date").melt(ignore_index=False,
                                   var_name="company",
                                   value_name='stock_value'),
    y="stock_value",
    x="date",
    barmode='stack',
    slicer="company",
)
fig.show()
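With `barmode='stack'` (a plotly bar mode), the bars of the different slices are piled up at each `x` value, so each segment starts where the previous one ends. A grouped cumulative sum on toy data reproduces the segment extents:

```python
import pandas as pd

# Toy long-form table: one value per (date, company)
df = pd.DataFrame({
    "date": ["d1", "d1", "d2", "d2"],
    "company": ["GOOG", "AAPL", "GOOG", "AAPL"],
    "stock_value": [1.0, 2.0, 1.5, 2.5],
})

# Top of each stacked segment = running total within the same date
df["top"] = df.groupby("date")["stock_value"].cumsum()
# Bottom of each segment = top minus its own height
df["base"] = df["top"] - df["stock_value"]
print(df)
```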

Color can be specified independently of the slicer by providing the `color` parameter:

In [13]:
df = px.data.medals_long()

fig = statsplot.barplot(
    data=df,
    barmode="group",
    x="nation",
    y="count",
    color="count",
    slicer='medal',
)
fig.show()

Data can be aggregated with a function passed as the `aggregation_func` argument. The method used to compute error bars can also be specified with `error_bar`:

In [14]:
df = px.data.tips()
fig = statsplot.barplot(data=df,
                        y="total_bill",
                        x="day",
                        slicer='sex',
                        color_palette="tab10",
                        aggregation_func='mean',
                        error_bar="bootstrap")

fig.show()
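A bootstrap error bar resamples the data with replacement and recomputes the aggregate on each resample. The sketch below computes a percentile bootstrap confidence interval for the mean of a few toy values. The 1000 resamples, the 95% level and the helper name `bootstrap_ci` are assumptions, not statsplot's documented defaults:

```python
import numpy as np

def bootstrap_ci(values, func=np.mean, n_boot=1000, level=0.95, seed=0):
    """Percentile bootstrap confidence interval of func(values) (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    # Each row of idx is one resample (drawn with replacement) of the original indices
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    stats = func(values[idx], axis=1)  # aggregate of every resample
    alpha = (1 - level) / 2
    return np.quantile(stats, [alpha, 1 - alpha])

bills = [16.99, 10.34, 21.01, 23.68, 24.59, 25.29, 8.77, 26.88]
low, high = bootstrap_ci(bills)
print(low, high)
```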

In [15]:
df = px.data.tips()
fig = statsplot.barplot(data=df,
                        y="total_bill",
                        x="day",
                        slicer='sex',
                        aggregation_func='median',
                        error_bar="iqr")

fig.show()
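This cell pairs the median with `error_bar="iqr"`, the interquartile range between the 25th and 75th percentiles. Pandas can compute the per-group statistics on toy data; how statsplot positions the bar ends is not shown here:

```python
import pandas as pd

df = pd.DataFrame({
    "day": ["Sat"] * 4 + ["Sun"] * 4,
    "total_bill": [10.0, 20.0, 30.0, 40.0, 5.0, 15.0, 25.0, 35.0],
})
grouped = df.groupby("day")["total_bill"]

# Median and first/third quartiles per group (linear interpolation by default)
summary = pd.DataFrame({
    "median": grouped.median(),
    "q1": grouped.quantile(0.25),
    "q3": grouped.quantile(0.75),
})
summary["iqr"] = summary["q3"] - summary["q1"]
print(summary)
```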

## Histogram plots
A representation of the distribution of numerical data, where the data are binned and the count in each bin is displayed.

The combination of the `step`, `rug`, `hist` and `kde` parameters allows fine control over the representation of the underlying distribution.

The `hlines` and `vlines` parameters are a convenient way to draw horizontal or vertical lines attached to a slice of the data. This is useful for highlighting particular values on the distribution.
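The examples below combine `bins` with an `equal_bins` flag. Assume `equal_bins=True` means all slices share the same bin edges; this is an interpretation the notebook does not spell out. Under that assumption, NumPy can sketch the difference between shared and per-slice bins:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy samples for two slices with different spreads
female = rng.normal(18, 5, size=100)
male = rng.normal(21, 8, size=150)

# Shared edges computed on the pooled data: both slices are binned identically
edges = np.histogram_bin_edges(np.concatenate([female, male]), bins=20)
female_counts, _ = np.histogram(female, bins=edges)
male_counts, _ = np.histogram(male, bins=edges)

# Per-slice edges: each slice gets its own 20 bins spanning its own range
female_counts_own, female_edges = np.histogram(female, bins=20)
```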

In [16]:
df = px.data.tips()
fig = statsplot.distplot(
    data=df,
    y="total_bill",
    step=True,
    rug=True,
    hist=True,
    equal_bins=False,
    bins=20,
    color_palette="Set2",
    slicer='sex',
    central_tendency="mean",
    hlines={'Female': ("Actual value", 50)},
)
fig.show()
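In the cell above, `hlines` maps a slice name to a `(label, value)` pair, and `central_tendency="mean"` marks the mean of each slice. Pandas can compute the per-slice means. The loop simply reads the same dict layout as the cell, and the data are toy values:

```python
import pandas as pd

df = pd.DataFrame({
    "sex": ["Female", "Female", "Male", "Male", "Male"],
    "total_bill": [10.0, 20.0, 12.0, 18.0, 30.0],
})
# Central tendency per slice, as requested with central_tendency="mean"
means = df.groupby("sex")["total_bill"].mean()

# Slice name -> (label, value), mirroring the hlines argument of the cell above
hlines = {"Female": ("Actual value", 50)}
for slice_name, (label, value) in hlines.items():
    print(f"{slice_name}: {label} at {value}, mean = {means[slice_name]}")
```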

In [17]:
df = px.data.tips()
fig = statsplot.distplot(data=df,
                         x="total_bill",
                         rug=True,
                         kde=True,
                         hist=True,
                         equal_bins=True,
                         bins=20,
                         color_palette="Set2_r",
                         slicer='sex',
                         central_tendency="median")
fig.show()

df = px.data.tips()
fig = statsplot.distplot(data=df,
                         x="total_bill",
                         rug=True,
                         kde=True,
                         hist=True,
                         equal_bins=False,
                         bins=20,
                         color_palette="Set2_r",
                         central_tendency="median")
fig.show()
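`kde=True` overlays a kernel density estimate on the histogram. Below is a NumPy sketch of a Gaussian KDE with Scott's rule-of-thumb bandwidth. The kernel, the bandwidth rule and the helper name `gaussian_kde_1d` are assumptions, since the notebook does not say which ones statsplot uses:

```python
import numpy as np

def gaussian_kde_1d(samples, grid, bandwidth=None):
    """Gaussian kernel density estimate evaluated on grid (hypothetical helper)."""
    samples = np.asarray(samples, dtype=float)
    grid = np.asarray(grid, dtype=float)
    if bandwidth is None:
        # Scott's rule of thumb in one dimension: std * n ** (-1/5)
        bandwidth = samples.std(ddof=1) * len(samples) ** (-1 / 5)
    # Average of one normal density per sample, centered on that sample
    z = (grid[:, None] - samples[None, :]) / bandwidth
    return np.exp(-0.5 * z**2).sum(axis=1) / (len(samples) * bandwidth * np.sqrt(2 * np.pi))

rng = np.random.default_rng(0)
samples = rng.normal(20, 8, size=200)
grid = np.linspace(-20, 60, 801)
density = gaussian_kde_1d(samples, grid)
```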

## Jointplots
A plot of two variables with bivariate and optionally univariate (i.e., marginal distribution) graphs.

In [18]:
df = px.data.tips()

fig = statsplot.jointplot(data=df,
                          x="total_bill",
                          y="tip",
                          step=False,
                          rug=True,
                          kde=True,
                          hist=False,
                          fit='linear',
                          plot_type='scatter',
                          equal_bins_x=True,
                          bins_x=20,
                          color_palette="Set2",
                          slicer='sex',
                          marginal_plot="all")
fig.show()

The underlying bivariate distribution can be drawn as a density map with `plot_type='scatter+kde'`. Since density maps cannot be visualized as overlays, a dropdown menu is created to switch between data slices.

In [19]:
df = px.data.tips()

fig = statsplot.jointplot(data=df,
                          x="total_bill",
                          y="tip",
                          step=False,
                          rug=True,
                          kde=True,
                          hist=False,
                          fit='linear',
                          plot_type='scatter+kde',
                          equal_bins_x=False,
                          bins_x=20,
                          color_palette="Set2",
                          kde_color_palette="greens",
                          slicer='sex',
                          shared_coloraxis=True,
                          marginal_plot="all"
                         )
fig.show()
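In plotly, such dropdowns are usually built with `layout.updatemenus`, where each button toggles the `visible` flag of every trace. The sketch below computes the visibility masks and button dicts for a hypothetical figure with two traces per slice. It shows the general mechanism, not statsplot's code:

```python
def visibility_masks(trace_slices, slices):
    """For each slice, one boolean per trace: True if the trace belongs to that slice."""
    return {s: [trace_slice == s for trace_slice in trace_slices] for s in slices}

# Hypothetical figure: a scatter trace and a density trace for each slice
trace_slices = ["Female", "Female", "Male", "Male"]
masks = visibility_masks(trace_slices, ["Female", "Male"])

# Each mask becomes a dropdown button that shows only that slice's traces,
# usable as fig.update_layout(updatemenus=[dict(buttons=buttons)])
buttons = [dict(label=s, method="update", args=[{"visible": mask}])
           for s, mask in masks.items()]
```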

In [23]:
df = px.data.tips()

fig = statsplot.jointplot(
    data=df,
    x="total_bill",
    y="tip",
    step=False,
    rug=True,
    kde=False,
    hist=True,
    plot_type='histogram',
    equal_bins_x=False,
    bins_x=20,
    color_palette="bone",
    kde_color_palette="reds",
    slicer='sex',
    marginal_plot="x"
)
fig.show()
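With `plot_type='histogram'`, the bivariate panel shows how many points fall in each 2-D bin. `numpy.histogram2d` computes such a count grid. The sketch uses synthetic data and does not reproduce statsplot's binning choices:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for the tips data
total_bill = rng.uniform(3, 50, size=244)
tip = 0.15 * total_bill + rng.normal(0, 1, size=244)

# counts[i, j] = number of points with total_bill in x-bin i and tip in y-bin j
counts, x_edges, y_edges = np.histogram2d(total_bill, tip, bins=(20, 10))
```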

The `histmap` plot type is useful for drawing a distribution along each unique value of another dimension. In the plot below (`plot_type='x_histmap'`), we display a histogram of stock values for each month:

In [21]:
df = px.data.stocks().set_index('date')
df.index = pd.DatetimeIndex(df.index, yearfirst=True, name="date")

fig = statsplot.jointplot(data=df.melt(ignore_index=False,
                                       var_name="company",
                                       value_name='stock_value'),
                          y="stock_value",
                          x="date",
                          plot_type='x_histmap',
                          marginal_plot="y")
fig.show()
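A histmap can be thought of as a stack of histograms, one per unique value of the other dimension. All of them share the same bin edges so that they line up. Here is a sketch on toy data; the `month` column and the 15 bins are assumptions for illustration:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Toy long-form table: 30 stock values per month
df = pd.DataFrame({
    "month": np.repeat(["2018-01", "2018-02", "2018-03"], 30),
    "stock_value": rng.lognormal(mean=0.1, sigma=0.2, size=90),
})

# Shared bin edges so that every column of the map uses the same value bins
edges = np.histogram_bin_edges(df["stock_value"], bins=15)
histmap = np.stack([
    np.histogram(group, bins=edges)[0]
    for _, group in df.groupby("month")["stock_value"]
])  # shape: (number of months, number of bins)
```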