# Data Visualisation


## Introduction

Here you'll see how to make charts to present data in an engaging and informative way.

There is a plethora of options (and packages) for data visualisation using code. First, though, a note about the different *philosophies* of data visualisation. There are broadly two categories of approach to using code to create data visualisations: *imperative*, where you build what you want, and *declarative*, where you say what you want. Choosing which to use involves a trade-off: imperative libraries offer you flexibility but at the cost of some verbosity; declarative libraries offer you a quick way to plot your data, but only if it's in the right format to begin with, and customisation may be more difficult.

In the rest of this chapter, we'll take a look at making visualisations with several of these libraries. But first, let's introduce them.

The most important and widely used data visualisation library in Python is [**matplotlib**](https://matplotlib.org/). It was used to make the first image of a black hole {cite}`bhpaper` and to image the first empirical evidence of gravitational waves {cite}`gwpaper`. **matplotlib** is an imperative visualisation library: you specify each part of what you want individually to build up an entire picture. It's perhaps the easiest to learn and the most difficult to master. As well as making plots, it can also be used to make diagrams and animations.

[**seaborn**](https://seaborn.pydata.org/index.html) is a popular declarative library that builds on **matplotlib** and works especially well with data that are in a *tidy* format (one row per observation, one column per variable).

[**plotnine**](https://plotnine.readthedocs.io/en/latest/) is another declarative plotting library but, rather than having lots of different functions (eg 'boxplot', 'violinplot', 'scatterplot') as **seaborn** does, it adopts a *grammar of graphics* approach. What this means is that *all* visualisations begin with the same command, `ggplot`, and are combinations of layers that address different aspects of a plot, for example points or lines, scale, labels, and so on. It's a bit difficult to grasp this idea in the abstract, but we'll see a concrete example soon enough.

[**altair**](https://altair-viz.github.io/) is yet another declarative plotting library for Python! It's most suited to interactive graphics on the web, and produces really beautiful charts. Under the hood, it calls a JavaScript library named **Vega-Lite** that's the sort of thing newspaper data visualisation units might use to make their infographics.

**pandas** also has built-in plotting functions that you will have seen in the data analysis part of this book. They are of the form `df.plot.*` where `*` could be, for example, `scatter`. These are convenience functions for making a quick plot of your data and they actually use **matplotlib**.
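
As a minimal sketch of what that looks like, the snippet below calls `df.plot.scatter` on a small made-up dataframe (the name `toy` and its random contents are assumptions for illustration only). Because the plotting is delegated to **matplotlib**, what comes back is an ordinary **matplotlib** axes object that you can carry on customising.

```python
import numpy as np
import pandas as pd

# Made-up data purely for illustration
rng = np.random.default_rng(0)
toy = pd.DataFrame({'x': rng.normal(size=50), 'y': rng.normal(size=50)})

# df.plot.scatter returns the matplotlib Axes it drew on
ax = toy.plot.scatter(x='x', y='y')
ax.set_title('A quick pandas scatter')
```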

## Getting started

We're going to start by learning a little bit more about each library by seeing the same example--a scatterplot--in each of them.

### Matplotlib

Each of these libraries has its own default theme for how it shows plots. Personally, I'm not much of a fan of the **matplotlib** default so the first thing I do is to change it to something more aesthetically pleasing and that uses the kind of fonts you'll see in a paper compiled with LaTeX.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Set seed for reproducibility
np.random.seed(10)
# Set max rows displayed for readability
pd.set_option('display.max_rows', 6)
# Plot settings
plot_style = {'xtick.labelsize': 20,
              'ytick.labelsize': 20,
              'font.size': 22,
              'figure.autolayout': True,
              'figure.figsize': (10, 5.5),
              'axes.titlesize': 22,
              'axes.labelsize': 20,
              'lines.linewidth': 4,
              'lines.markersize': 6,
              'legend.fontsize': 16,
              'mathtext.fontset': 'stix',
              'font.family': 'STIXGeneral',
              'legend.frameon': False}
plt.style.use(plot_style)

You'll see we imported `matplotlib.pyplot as plt` above; this is the main part of **matplotlib** that we'll use in practice.

There are actually two ways to use **matplotlib**: the 'pyplot API' and the 'object-oriented API'. The pyplot API (application programming interface) gives an experience that's much closer to MATLAB and can be useful to make a quick chart. It's very simple to use, for example:

In [None]:
plt.plot([1, 2, 3, 4], [1, 4, 9, 16])

```{admonition} Tip
:class: tip
**Matplotlib** returns an object when used, eg `[<matplotlib.lines.Line2D...` above. To suppress this, end the command with a semi-colon, `;`.
```

However, the pyplot API is generally less flexible and less useful than the object-oriented API, and for the rest of this chapter we'll be using the object-oriented API.

The object-oriented API is most often used by creating two objects: the figure and the axes. You should think of the figure object, `fig`, as the canvas on which you will put any number of charts. Each `ax` (short for 'axes') object is one chart within a figure. Of course, most of the time you're likely to have only one `ax` per figure, but in the cases when you don't it's a really useful setup. The plotting of elements such as lines, points, bars, and so on is controlled by the `ax` objects, while the overall settings are controlled by `fig`.
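
To make the figure-versus-axes distinction concrete, here is a minimal sketch with two charts on one canvas. The data and the names `ax_left` and `ax_right` are made up for illustration.

```python
import matplotlib.pyplot as plt

# One figure (the canvas) holding two axes (the individual charts)
fig, (ax_left, ax_right) = plt.subplots(1, 2)

# Chart-level content and settings go on each ax
ax_left.plot([1, 2, 3], [1, 4, 9])
ax_left.set_title('Left chart')
ax_right.scatter([1, 2, 3], [3, 1, 2])
ax_right.set_title('Right chart')

# Figure-level settings go on fig
fig.suptitle('One figure, two axes');
```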

Let's see an example of a scatter plot using this object-oriented approach.

In [None]:
fig, ax = plt.subplots()  # Create a figure containing a single axes.
ax.scatter([1, 2, 3, 4, 5, 6], [1, 4, 2, 3, 1, 7], s=150, c='b');  # Plot some data on the axes.

In the above example, we used `ax.scatter` to get a scatter plot, `s=150` to set the area of the points, and `c='b'` to set the colour. Many of these keyword arguments will accept an array instead of a single value and will map them into the plot in the way you'd expect, for instance:

In [None]:
fig, ax = plt.subplots()
ax.scatter([1, 2, 3, 4, 5, 6], [1, 4, 2, 3, 1, 7],
           s=np.linspace(300, 2000, 6),
           c=['b', 'r', 'g', 'k', 'cyan', 'yellow'],
           edgecolors='k',
           alpha=0.5);

Here we asked for a different colour for each point, an area that's increasing linearly, partly transparent points (default is `alpha=1`, which is a solid colour), and a black edge colour.

You can probably begin to see how everything is going to be customisable. So far, though, we've only seen aspects of the plot that are customisable through keyword arguments to `scatter`; let's now see an example that's a bit more real (and useful!) in which we'll want to add labels, a title, and more. We'll use the Midwest demographics dataset.

In [None]:
df = (pd.read_csv('https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2/midwest.csv', index_col='PID')
      .drop('Unnamed: 0', axis=1))
df.head()

In [None]:
fig, ax = plt.subplots()
ax.scatter(df['area'], df['poptotal'], edgecolors='k', alpha=0.6)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel('Area')
ax.set_ylabel('Population')
ax.set_title('Area vs. Population', loc='right');

Perhaps you can see from this already that **matplotlib** offers a ton of customisation features *and* that it can be quite verbose.

Let's see another couple of customisations that are really useful: formatting axes. On the y-axis, we'll get rid of the awkward E notation for one million (1e6) and, on the x-axis, we'll add a percentage suffix on the numbers plus some minor tick marks.

In [None]:
from matplotlib.ticker import AutoMinorLocator

fig, ax = plt.subplots()
ax.scatter(df['area'], df['poptotal'], edgecolors='k', alpha=0.6)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel('Area')
ax.set_ylabel('Population')
ax.set_title('Area vs. Population', loc='right')
ax.xaxis.set_minor_locator(AutoMinorLocator(4))
ax.xaxis.set_major_formatter('{x:.2f}%')
ax.ticklabel_format(style='sci', scilimits=(-2, 2), axis='y', useMathText=True);

`AutoMinorLocator(4)` divides the interval between each pair of major tick marks into 4, which puts 3 minor tick marks between them. The major formatter is `'{x:.2f}%'`, which says print the value to 2 decimal places followed by a % sign. Finally, `ax.ticklabel_format` with `useMathText=True` swaps the E notation for a prettier, LaTeX-style rendering of the power of ten.
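
If you want to check how a locator or formatter behaves without squinting at a chart, you can query it directly. The sketch below is illustrative only: the major tick spacing of 0.02 is pinned with `MultipleLocator` as an assumption so that the count is predictable, and `StrMethodFormatter` is the formatter class that **matplotlib** builds from a format string.

```python
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator, StrMethodFormatter

fig, ax = plt.subplots()
ax.set_xlim(0, 0.1)
ax.xaxis.set_major_locator(MultipleLocator(0.02))  # major ticks every 0.02 (fixed for illustration)
ax.xaxis.set_minor_locator(AutoMinorLocator(4))    # split each major interval into 4 parts

# Count the minor ticks that fall between the major ticks at 0 and 0.02
minor = ax.xaxis.get_minorticklocs()
n_between = int(sum(0.001 < t < 0.019 for t in minor))
print(n_between)  # 3

# Apply the format string to a single value
label = StrMethodFormatter('{x:.2f}%')(0.05)
print(label)  # 0.05%
```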

Rather than go through all of the many, many options for customisation, the figure below gives an overview of the options:

![Anatomy of a matplotlib figure](https://matplotlib.org/_images/anatomy.png)

While **matplotlib** is super-customisable, sometimes achieving what you want directly can be a bit verbose. Let's say we want to now differentiate these points with colour according to which state they belong to and add a legend that says which states have which colour. The easiest way to do this is by creating a for loop.

In [None]:
colours = plt.get_cmap('Dark2')(np.linspace(0, 1, len(df['state'].unique())))

fig, ax = plt.subplots()
for i, state in enumerate(df['state'].unique()):
    xf = df.loc[df['state'] == state]
    ax.scatter(xf['area'], xf['poptotal'], color=colours[i], label=state, s=100, edgecolor='k', alpha=0.8)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel('Area')
ax.set_ylabel('Population')
ax.set_title('Area vs. Population', loc='center')
ax.xaxis.set_minor_locator(AutoMinorLocator(4))
ax.xaxis.set_major_formatter('{x:.2f}%')
ax.ticklabel_format(style='sci', scilimits=(-2, 2), axis='y', useMathText=True)
ax.legend(title='State');

Okay, so we managed to get what we wanted. We used a colormap to get 5 qualitatively different colours; there are also sequential colormaps for continuous (as opposed to discrete) variables. You can find out more about the colormaps available in base **matplotlib** [here](https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html). We also had to subset the dataframe and loop over it state by state, which shows that **matplotlib** doesn't always work so well with tidy data. Some of its defaults are a little more friendly when using data that are unstacked, eg with one state per column.

However, this was quite verbose; when we come to **seaborn**, we'll see a much easier way of doing this. 
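
On the point about unstacked data, here is a minimal sketch of the friendlier case. The dataframe `wide` and its numbers are made up for illustration; the idea is only that a two-dimensional block of y-values gives one line per column without a loop.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Made-up wide ("unstacked") data: one column per state, one row per year
wide = pd.DataFrame(
    {'IL': [1.0, 1.2, 1.5], 'OH': [0.8, 0.9, 1.1], 'WI': [0.5, 0.7, 0.6]},
    index=[2000, 2001, 2002],
)

fig, ax = plt.subplots()
# A 2D array of y-values produces one line per column
lines = ax.plot(wide.index.to_numpy(), wide.to_numpy())
ax.legend(lines, list(wide.columns), title='State')
ax.set_xlabel('Year');
```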

For now, let's see a couple of other plot types that are useful, beginning with contour plots:

In [None]:
def f(x, y):
    return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)

x = np.linspace(0, 5, 100)
y = np.linspace(0, 5, 100)

X, Y = np.meshgrid(x, y)
Z = f(X, Y)

fig, ax = plt.subplots()
cf = ax.contourf(X, Y, Z, cmap='plasma')
ax.set_title(r'$f(x,y) = \sin^{10}(x) + \cos(x)\cos\left(10 + y\cdot x\right)$')
cbar = fig.colorbar(cf)

This demonstrates creating a filled contour plot (similar to a heatmap) with a colour bar legend and a title that's rendered using LaTeX-style maths. The plot uses a perceptually uniform colormap, which makes equal changes in the data look like equal changes in colour; **matplotlib** has a few of these. If you need more colours, check out the packages [**colorcet**](https://colorcet.holoviz.org/) and [**palettable**](https://jiffyclub.github.io/palettable/).
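
The same kind of colormap can encode a continuous variable in a scatter plot. This is a rough sketch on made-up data (the names `xs`, `ys`, and `zs` are assumptions for illustration), using `viridis`, one of the perceptually uniform colormaps that ships with **matplotlib**.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)       # made-up data for illustration
xs, ys = rng.uniform(size=(2, 100))
zs = xs + ys                         # a continuous variable to encode as colour

fig, ax = plt.subplots()
# Passing numbers to c maps them through the colormap
sc = ax.scatter(xs, ys, c=zs, cmap='viridis', edgecolors='k')
fig.colorbar(sc, label='x + y');
```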

You can do some really quite amazing things using **matplotlib**'s build-what-you-want philosophy. For instance, you can [build a timeline](https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/timeline.html#sphx-glr-gallery-lines-bars-and-markers-timeline-py), make [charts with polar axes](https://matplotlib.org/3.1.1/gallery/pie_and_polar_charts/polar_bar.html#sphx-glr-gallery-pie-and-polar-charts-polar-bar-py), and create [XKCD style plots](https://matplotlib.org/3.1.1/gallery/showcase/xkcd.html#sphx-glr-gallery-showcase-xkcd-py).

One basic bit of functionality that you might need is to put more than one type of information on a single plot. Using the object-oriented API, this is as simple as calling another method on an `ax` that you've already created. In the example below, we'll call `ax.hist` followed by `ax.plot` to get the theoretical curve for a normal distribution (aka Gaussian) overlaid on a normalised histogram of many draws from that distribution, generated using `numpy`.

In [None]:
rand_draws = np.random.randn(5000)
grid_x = np.linspace(-5, 5, 1000)

fig, ax = plt.subplots()
ax.hist(rand_draws, bins=50, density=True)
ax.plot(grid_x, 1 / np.sqrt(2*np.pi) * np.exp(-(grid_x**2)/2), linewidth=4);
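
The curve and the bars share a scale because `density=True` rescales the bar heights so that the total area of the bars is one, just like a probability density. A quick way to convince yourself is to apply the same normalisation with `np.histogram` and add up the area; the sketch below uses its own made-up draws rather than the ones above.

```python
import numpy as np

rng = np.random.default_rng(2)
draws = rng.standard_normal(5000)

# np.histogram with density=True applies the same normalisation as ax.hist
heights, edges = np.histogram(draws, bins=50, density=True)

# Bar heights times bar widths: equal to 1 up to floating-point error
area = np.sum(heights * np.diff(edges))
```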

Another important class of plots that we should look at is line charts, especially time series. Let's grab the real GDP time series for the UK and US and look at the year-on-year growth rates of the quarterly data (the first few entries will be `NaN` because there is no observation from a year earlier to compare them with).

In [None]:
import pandas_datareader.data as web

ts_start_date = pd.to_datetime('1999-01-01')

df = pd.concat([web.DataReader('ticker=RGDP' + x, 'econdb', start=ts_start_date) for x in ['US', 'UK']], axis=1)
df.columns = ['US', 'UK']
df.index.name = 'Date'
df = 100*df.pct_change(4)
df.head()
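
To see what `pct_change(4)` is doing, here is a small sketch on made-up quarterly levels (the series `levels` is an assumption for illustration). Each value is compared with the one four rows earlier, so the first four entries have nothing to be compared with.

```python
import pandas as pd

# Made-up quarterly levels, rising by 1 each quarter
levels = pd.Series(
    [100.0, 101.0, 102.0, 103.0, 104.0, 105.0],
    index=pd.date_range('2000-01-01', periods=6, freq='QS'),
)

# Compare each quarter with the same quarter one year (4 rows) earlier
yoy = 100 * levels.pct_change(4)
print(yoy.isna().sum())  # 4
```

The fifth entry compares 104 with 100, which is growth of roughly 4%.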

Okay, now the quick way to plot this would be to simply call `df.plot()`. That works fine and doesn't look too bad: it recognises that it's dealing with a datetime and plots sensible tick labels on the x-axis, and it produces a legend for each of the two time series. But let's say we want to make that plot quickly *and* do some fine-tuning with **matplotlib**: how can we? The answer is to create a `fig` and an `ax` and then to ask our dataframe to use the `ax` we already created to plot the data. We can then carry on using the `ax` object in any way we like.

In [None]:
fig, ax = plt.subplots()
df.plot(ax=ax)
ax.set_title('Real GDP growth, %', loc='right')
ax.yaxis.tick_right(); # Put tick marks and tick labels on right-hand side

We could also have created this chart by looping through the different countries just as we looped through the different states in the previous example.

Let's also see what happens when we use different limits:

In [None]:
fig, ax = plt.subplots()
df.plot(ax=ax)
ax.set_title('Real GDP growth, %', loc='right')
ax.yaxis.tick_right() # Put tick marks and tick labels on right-hand side
ax.set_xlim(pd.to_datetime('2018-01-01'), pd.to_datetime('2021-01-01'));

As you can see, the plot dynamically responded to the shorter time period by putting more details in, here of quarters. You can specify exactly what you want with the tick label formatters that cater to datetimes, but the defaults are pretty well-behaved.
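
As a rough sketch of taking control of datetime ticks yourself, the example below uses the locators and formatters in `matplotlib.dates` on a made-up quarterly series drawn with `ax.plot`. One caveat, stated as an assumption rather than a guarantee: when a time series is drawn with `df.plot`, **pandas** may apply its own date handling, and passing `x_compat=True` to `df.plot` is the option its documentation describes for switching that off.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# Made-up quarterly series for illustration
dates = pd.date_range('2018-01-01', periods=12, freq='QS')
values = np.linspace(1.0, 3.0, len(dates))

fig, ax = plt.subplots()
ax.plot(dates, values)
ax.xaxis.set_major_locator(mdates.YearLocator())          # one major tick per year
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))  # label it with the year only
ax.xaxis.set_minor_locator(mdates.MonthLocator(bymonth=[4, 7, 10]))  # unlabelled quarter ticks
```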

Sometimes, you want to show different lines in different panels (or facets, as they are also known). You can do that with **matplotlib** too, though once again it's a build-your-own affair.

Let's put these two time series in their own facets to see how it works.

In [None]:
fig, axes = plt.subplots(1, 2)
for i, ax in enumerate(axes):
    ax.plot(df.index, df.iloc[:, i])
    ax.set_title(df.columns[i], loc='left')
    ax.yaxis.tick_right()
    ax.set_ylim(df.min().min(), df.max().max())
fig.suptitle('Real GDP growth, %', y=0.93);

Quite a few things are happening here. The first is that we asked for more rows and columns from **matplotlib** using `plt.subplots(nrows, ncols)` and, instead of returning a single `ax`, it returned a NumPy array of axes (here of length 2). We then iterated over those axes using `enumerate`, which gives an integer `i` and an `ax` from the array, and plotted a column on each `ax` by subsetting the dataframe one column at a time using `df.iloc[:, i]`. Remember that the first position in `iloc` is for the index values (ie the rows), while the second is for the columns. Again, this was quite a lot of code to get a relatively simple chart done! And that's exactly why we're now going to take a look at the declarative library **seaborn** that wraps **matplotlib**.
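
Before moving on, one small aside on grids of axes. With more than one row *and* more than one column, the array of axes is two-dimensional, and `sharey=True` is an alternative to setting matching y-limits by hand. The sketch below uses made-up sine curves purely for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

# A 2x2 grid comes back as a 2D NumPy array of axes
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
grid = np.linspace(0, 2 * np.pi, 100)

# .flat walks the grid row by row, so one loop covers every panel
for k, ax in enumerate(axes.flat, start=1):
    ax.plot(grid, np.sin(k * grid))
    ax.set_title(f'sin({k}x)', loc='left')
```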

### Seaborn

**seaborn** provides a high-level interface for quickly drawing standard charts. It is based on **matplotlib**, which is great because it means you can always tinker if you need to. And it plays very nicely with **pandas** dataframes too, so it can easily fit into your workflow.

Let's see the chart we were just trying to make but rendered in **seaborn**. The first difference is that **seaborn** expects *tidy data*, a concept from the Data part of this book. So first we must transform our data into tidy format.



In [None]:
tidy_df = df.stack().reset_index()
tidy_df.columns = ['Date', 'Country', 'Real GDP growth, %']
tidy_df.head()
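
An alternative route to the same tidy layout is `melt`, which lets you name the new columns in one step. This is a sketch on a made-up dataframe (`toy_wide` is an assumption for illustration), not a replacement for the code above. Note that `melt` orders the rows column by column (all of the US rows, then all of the UK rows), whereas `stack` orders them by date.

```python
import pandas as pd

# Made-up wide data: one column per country
toy_wide = pd.DataFrame(
    {'US': [2.0, 2.5, 3.0], 'UK': [1.0, 1.5, 2.0]},
    index=pd.date_range('2000-01-01', periods=3, freq='QS', name='Date'),
)

# Move the index into a column, then gather the country columns into rows
toy_tidy = toy_wide.reset_index().melt(
    id_vars='Date', var_name='Country', value_name='Real GDP growth, %'
)
```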

Okay, now let's use **seaborn** to do the plotting. As with **matplotlib**, the way to get different charts in **seaborn** is a matter of knowing the API names. Here, we'll use `relplot`, a figure-level function for relational plots (scatter and line) that can split the data into facets.

In [None]:
import seaborn as sns

sns.relplot(
    data=tidy_df,
    x="Date", y="Real GDP growth, %",
    hue="Country", col="Country",
    kind="line", legend=False,
);

In this case, we got a very similar plot with less effort--and what we did is more generalisable too. As long as we have data in a tidy format, we can specify the columns to **seaborn** and let it put them in different facets. And facets aren't the only thing it can vary either: you can specify the hue such that it depends on another categorical variable (rather than, as in this case, changing with each facet). Here's an example changing the hue to reflect an extra type of information, region:


In [None]:
# Download RGDP for Canada and France:
extra_df = pd.concat([web.DataReader('ticker=RGDP' + x, 'econdb', start=ts_start_date) for x in ['CA', 'FR']], axis=1)
extra_df.columns = ['CA', 'FR']
extra_df.index.name = 'Date'
extra_df = (100*extra_df.pct_change(4)).stack().reset_index()
extra_df.columns = ['Date', 'Country', 'Real GDP growth, %']
# Add the new data to the tidy dataframe
tidy_df = pd.concat([tidy_df, extra_df], axis=0)
continent_dict = {'CA': 'North America', 'US': 'North America',
                  'UK': 'Europe', 'FR': 'Europe'}
tidy_df['Region'] = tidy_df['Country'].map(continent_dict)
tidy_df.sample(6)

In [None]:
sns.relplot(
    data=tidy_df,
    x="Date", y="Real GDP growth, %",
    hue="Region", col="Country", col_wrap=2,
    kind="line", legend=False,
);

As you can see, we now have two colours: one for each region.

Let's see some of the other useful shortcuts that **seaborn** provides over the top of **matplotlib**. First, correlation heatmaps:


In [None]:
# Generate data and create corr mat
d = pd.DataFrame(data=np.random.normal(size=(100, 6)),
                 columns=['A', 'B', 'C', 'D', 'E', 'F'])
corr = d.corr()

# Generate a mask to cover the upper triangle
mask = np.triu(np.ones_like(corr, dtype=bool))

# Draw heatmap
sns.heatmap(corr, mask=mask, cmap='magma', vmin=-.4, vmax=.4, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5});
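
It can help to look at the mask on its own. In the sketch below, a small made-up symmetric matrix (`toy_corr`, an assumption for illustration) stands in for the correlation matrix. Cells where the mask is `True` are the ones the heatmap hides, so the mask used above hides the diagonal as well; passing `k=1` to `np.triu` would keep the diagonal visible.

```python
import numpy as np

# A small made-up symmetric matrix standing in for a correlation matrix
toy_corr = np.array([[1.0, 0.2, -0.1],
                     [0.2, 1.0, 0.3],
                     [-0.1, 0.3, 1.0]])

# True marks the cells to hide
mask_with_diag = np.triu(np.ones_like(toy_corr, dtype=bool))     # upper triangle and diagonal
mask_no_diag = np.triu(np.ones_like(toy_corr, dtype=bool), k=1)  # strictly above the diagonal

print(mask_with_diag.sum(), mask_no_diag.sum())  # 6 3
```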

**seaborn** also offers convenience functions for (continuous) heatmaps, kernel density estimates, and marginal plots:


In [None]:
df = sns.load_dataset('penguins')

g = sns.JointGrid(data=df, x="body_mass_g", y="bill_depth_mm", space=0)
g.plot_joint(sns.kdeplot,
             fill=True, clip=((2200, 6800), (10, 25)),
             thresh=0, levels=100, cmap="inferno")
g.plot_marginals(sns.histplot, color="#03051A", alpha=1, bins=25);

And plotting simple linear models according to a category:

In [None]:
g = sns.lmplot(
    data=df,
    x="bill_length_mm", y="bill_depth_mm", hue="species",
    legend=False, line_kws={'lw': 1.}, height=5, aspect=1.75
)
plt.legend(loc='lower left', frameon=True);

And, finally, violin plots:

In [None]:
tips = sns.load_dataset("tips")
sns.violinplot(data=tips, x="day", y="total_bill", hue="smoker",
               split=True, inner="quart", linewidth=1);