(vis-matplotlib)=
# Powerful Data Visualisation with **Matplotlib**


## Introduction

Here you'll see how to make just about any plot you can possibly imagine using the ferociously powerful *imperative* graphing package, **matplotlib**. If you read on to the chapter on {ref}`vis-narrative`, you'll see just how flexible it is. **matplotlib** is the foundation stone of quite a few other data visualisation libraries, such as **plotnine** and **seaborn**, which shows how much it can do!

It's worth saying, though, that **matplotlib** has a very different philosophy from **seaborn** and **plotnine**. Apart from being *imperative* rather than *declarative*, it also prefers unstacked (or "wide") data to tidy data. For example, if you want one line per country, tidy data would store every country's name in a single column called "country"; to create plots quickly, **matplotlib** instead prefers *each* country to have its own column.
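To make the difference concrete, here is a minimal sketch using a small, made-up dataset (the country names and values are purely illustrative). `DataFrame.pivot` turns the tidy table into one column per country, and **matplotlib** can then draw each column as its own line:

```python
import matplotlib.pyplot as plt
import pandas as pd

# A small, made-up tidy dataset: one row per (year, country) observation
tidy = pd.DataFrame(
    {
        "year": [2020, 2021, 2022, 2020, 2021, 2022],
        "country": ["A", "A", "A", "B", "B", "B"],
        "value": [1.0, 2.0, 3.0, 2.0, 1.5, 2.5],
    }
)

# Unstack so that each country gets its own column, indexed by year
wide = tidy.pivot(index="year", columns="country", values="value")

# With wide data, each column becomes one line on the chart
fig, ax = plt.subplots()
for country in wide.columns:
    ax.plot(wide.index, wide[country], label=country)
ax.legend(title="Country")
```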

For a more in-depth introduction to **matplotlib**, head over to [this tutorial](https://github.com/rougier/matplotlib-tutorial). This chapter is indebted to that tutorial, to the excellent [**matplotlib** documentation](https://matplotlib.org/stable/tutorials/index.html), and to the book *Scientific Visualization in Matplotlib* {cite:t}`rougier2021scientific`.

As ever, we'll start by importing some key packages and initialising any settings:

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Set seed for random numbers
seed_for_prng = 78557
prng = np.random.default_rng(seed_for_prng)  # prng = pseudo-random number generator

In [None]:
import matplotlib_inline.backend_inline

# Plot settings
plt.style.use(
    "https://github.com/aeturrell/coding-for-economists/raw/main/plot_style.txt"
)
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")

# Set max rows displayed for readability
pd.set_option("display.max_rows", 6)

## Understanding Matplotlib

You'll see we imported `matplotlib.pyplot as plt` above; this is the main part of **matplotlib** that we'll use in practice.

### The **matplotlib** API

**matplotlib** {cite:t}`hunter2007matplotlib` began as an attempt to replicate the plotting functionality of the proprietary programming language MATLAB, but it has since outgrown these roots, and we're going to show you how to use it in the most productive way. It is worth saying, though, that the **matplotlib** API (application programming interface) is huge, so there's no way we'll be able to cover *everything* it can do.

We'll be using the 'object-oriented API' to work with **matplotlib**, which is the recommended way. This just means that we create objects, like *figures* and *axes*, that have state (they remember what you did to them).

The object-oriented API is most often used by creating two fundamental objects that almost every chart needs: the figure and the axes. You should think of the figure object, `fig`, as the canvas on which you can put any number of charts. Each `ax` object (short for 'Axes', **matplotlib**'s name for a single plotting area) is one chart within a figure. Most of the time you're likely to have only one Axes per figure, but when you need more, this setup is really useful. Plot elements such as lines, points, and bars are controlled by the `ax` objects, while figure-wide settings are controlled by `fig`.
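As a minimal sketch of this division of labour (the data and titles are made up), here is one figure holding two Axes. Each chart is drawn and titled through its own Axes object, while the overall title belongs to the figure:

```python
import matplotlib.pyplot as plt

# One figure (the canvas) containing two Axes (two separate charts)
fig, (ax_left, ax_right) = plt.subplots(1, 2)

# Plot elements are added via each Axes object...
ax_left.plot([1, 2, 3], [3, 1, 2])
ax_left.set_title("Left chart")
ax_right.bar(["a", "b", "c"], [2, 5, 3])
ax_right.set_title("Right chart")

# ...while figure-level settings, such as an overall title, go on the figure
fig.suptitle("One figure, two Axes")
```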

The simplest chart we can think of is a line plot. Let's see how to make one:

In [None]:
fig, ax = plt.subplots()  # Get the figure and one axis as a subplot
ax.plot(
    [1, 2, 3, 4, 5, 6],  # Add some data on the x and y axes
    [1, 4, 2, 3, 1, 7],
)


Notice that **matplotlib** will happily accept raw data, such as plain lists of numbers, though it works with dataframes too.

```{admonition} Tip
:class: tip
**Matplotlib** returns an object when used in certain contexts, eg it might return `[<matplotlib.lines.Line2D at...` above. To suppress this, end the command with a semi-colon, `;`, or call `plt.show()` as the last command.
```
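That printed object is simply the method's return value, and it can be handy to keep hold of it. For instance, `ax.plot` returns a list of `Line2D` objects, one per line drawn, which you can adjust after plotting. A minimal sketch:

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
# ax.plot returns a list of Line2D objects; here there is just one line
lines = ax.plot([1, 2, 3, 4, 5, 6], [1, 4, 2, 3, 1, 7])
line = lines[0]

# The returned object can be modified after it has been drawn
line.set_linewidth(3)
line.set_linestyle("--")
```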

Let's see an example of a scatter plot using this object-oriented approach. Note that we begin in the same way, by getting a figure and an axis. But now we're going to use `ax.scatter` instead, plus throw in a couple of extra settings (can you guess what they do?).

In [None]:
fig, ax = plt.subplots()  # Create a figure containing a single axes
ax.scatter(
    [1, 2, 3, 4, 5, 6],  # Plot some data on the axes
    [1, 4, 2, 3, 1, 7],
    s=150,
    c="b",
);

`s=150` sets the size of each marker (strictly, its area in points squared), and `c="b"` sets the colour (here, blue). Many of these settings will accept an *array* of values (like `s=[1, 2, 3, 4, 5, 6]`) instead of a single value, and will map one value to each point in the way you'd expect.

Let's see that in practice. For the sizes, we'll use six evenly spaced values between 300 and 2000 (via `np.linspace`). For the colours, we'll just type out six distinct values.

In [None]:
fig, ax = plt.subplots()
ax.scatter(
    [1, 2, 3, 4, 5, 6],
    [1, 4, 2, 3, 1, 7],
    s=np.linspace(300, 2000, 6),
    c=["b", "r", "g", "k", "cyan", "yellow"],
    edgecolors="k",
    alpha=0.5,
);

We also asked for partly transparent points via the `alpha` setting (the default, `alpha=1`, is fully opaque), and for black marker edges via `edgecolors="k"`.

As ever, you can call help on a specific command in order to understand what options (keyword arguments) it accepts. And, in Visual Studio Code, you can just hover over the function or method name.

Here is what `help` returns when you run it on the `scatter` method:

In [None]:
help(ax.scatter)

There are a huge number of options for what to put on axes; the table below gives a guide to the most essential ones for 2D plots.

| Code | What it does |
|---|---|
| `Axes.plot` | Plot y versus x as lines and/or markers. |
| `Axes.errorbar` | Plot y versus x as lines and/or markers with attached errorbars. |
| `Axes.scatter` | A scatter plot of y vs x |
| `Axes.step` | Make a step plot. |
| `Axes.loglog` | Make a plot with log scaling on both the x and y axis. |
| `Axes.semilogx` | Make a plot with log scaling on the x axis. |
| `Axes.semilogy` | Make a plot with log scaling on the y axis. |
| `Axes.fill_between` | Fill the area between two horizontal curves. |
| `Axes.fill_betweenx` | Fill the area between two vertical curves. |
| `Axes.bar` | Make a bar plot. |
| `Axes.barh` | Make a horizontal bar plot. |
| `Axes.bar_label` | Label a bar plot. |
| `Axes.stem` | Create a stem plot. |
| `Axes.eventplot` | Plot identical parallel lines at the given positions. |
| `Axes.pie` | Plot a pie chart. |
| `Axes.stackplot` | Draw a stacked area plot. |
| `Axes.broken_barh` | Plot a horizontal sequence of rectangles. |
| `Axes.vlines` | Plot vertical lines at each x from ymin to ymax. |
| `Axes.hlines` | Plot horizontal lines at each y from xmin to xmax. |
| `Axes.axhline` | Add a horizontal line across the Axes. |
| `Axes.axhspan` | Add a horizontal span (rectangle) across the Axes. |
| `Axes.axvline` | Add a vertical line across the Axes. |
| `Axes.axvspan` | Add a vertical span (rectangle) across the Axes. |
| `Axes.axline` | Add an infinitely long straight line. |
| `Axes.fill` | Plot filled polygons. |
| `Axes.boxplot` | Draw a box and whisker plot. |
| `Axes.violinplot` | Make a violin plot. |
| `Axes.violin` | Drawing function for violin plots. |
| `Axes.bxp` | Drawing function for box and whisker plots. |
| `Axes.hexbin` | Make a 2D hexagonal binning plot of points x, y. |
| `Axes.hist` | Compute and plot a histogram. |
| `Axes.hist2d` | Make a 2D histogram plot. |
| `Axes.stairs` | A stepwise constant function as a line with bounding edges or a filled plot. |
| `Axes.contour` | Plot contour lines. |
| `Axes.contourf` | Plot filled contours. |
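As a sketch of how several of these methods can be combined (all data here are made up), the left chart layers `fill_between` and `axhline` on top of a line, and the right chart uses `bar` followed by `bar_label`:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))

# Left: a line, a shaded band around it, and a horizontal reference line
x = np.linspace(0, 10, 50)
y = np.sin(x)
ax1.plot(x, y)
ax1.fill_between(x, y - 0.2, y + 0.2, alpha=0.3)  # shaded band of +/- 0.2
ax1.axhline(0, color="k", linewidth=1)  # spans the full width of the Axes

# Right: a bar chart with the value of each bar written on it
bars = ax2.bar(["A", "B", "C"], [3, 7, 5])
labels = ax2.bar_label(bars)  # one text label per bar
```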

Now, it might not seem like it, but you already have the makings of a wide range of **matplotlib** plots. You just need to follow this recipe:

1. `fig, ax = plt.subplots()`
2. Choose what you want to put on your axis, for example `ax.scatter` or `ax.plot`
3. Put data into the method you chose in step 2 in the required format. You can check the documentation to get the right format: use `help`, hover over the method name in Visual Studio Code, or head to the **matplotlib** documentation, which often has examples to look at (here are the [examples for hex bins](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.hexbin.html#examples-using-matplotlib-axes-axes-hexbin)).
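Following the recipe for a hex bin plot might look like the sketch below; the data are made-up random draws, and `gridsize=20` is an arbitrary choice of hexagon resolution:

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up data: 1,000 correlated random draws
rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
y = x + rng.standard_normal(1000)

fig, ax = plt.subplots()  # step 1
hb = ax.hexbin(x, y, gridsize=20, cmap="viridis")  # steps 2 and 3
fig.colorbar(hb, ax=ax, label="Number of points")  # key for the colours
```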


### Anatomy of a **matplotlib** graph

So, you've made a nice plot and now you want to tweak it. Well, **matplotlib** certainly has a LOT of choice when it comes to tweaking! Rather than go through all of the many, many options for customisation, the figure below (from the documentation) gives an overview of the options:

![Anatomy of a matplotlib figure](https://matplotlib.org/_images/anatomy.png)

Let's run through a few of the most important plot elements:

- Figure, or 'fig': the figure keeps track of all the child Axes that are on it and a smattering of 'special' artists (titles, figure legends, etc). A figure can contain any number of Axes, but will typically have at least one if it is to be interesting!

- Axes, or 'ax': this is the plot, the region of the image that traces out the data. A given Figure can contain many Axes, but a given Axes object can only be in one Figure. The Axes contains two (or three in the case of 3D) Axis objects (Axes and Axis are different things!) that record the data limits (you can override these via the `axes.Axes.set_xlim()` and `axes.Axes.set_ylim()` methods). Each Axes object has a title (set via `set_title()`), an x-label (set via `set_xlabel()`), and a y-label (set via `set_ylabel()`). When you add, say, a line chart to an Axes object, it is created by calling a method on that Axes and appears as a Line2D object associated with that Axes.

- Axis: these are the number-line objects that control the limits of what the viewer can see. They also provide the means to access the ticks (the marks on the axis) and ticklabels (strings labeling the ticks). The location of the ticks is determined by a Locator object and the ticklabel strings are formatted by a Formatter. The combination of Locator and Formatter gives very fine control over the tick locations and labels.

- Markers: the symbols that mark individual data points, such as those produced by scatter plots (line plots can have markers too).

- Labels and titles: text to help the viewers of the chart make sense of what they're seeing.

- Legend: if needed, contains the key to understand the shapes or colours used in lines or markers or bars (or ...) that are on the chart.
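Here is a minimal sketch that touches most of these elements at once, using made-up curves: lines (one with markers), labels and a title, a Locator and a Formatter on the x-axis, and a legend. The `MultipleLocator(0.5)` spacing and the " km" suffix in the hypothetical `with_km` function are illustrative choices:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter, MultipleLocator

x = np.linspace(0, 2, 50)

fig, ax = plt.subplots()  # the Figure and one Axes

# Lines on the Axes; the label keyword is what the legend will show
ax.plot(x, x**2, label=r"$y = x^2$")  # text between $...$ is rendered as maths
ax.plot(x, x**3, marker="o", markevery=10, label=r"$y = x^3$")

# Labels and title
ax.set_xlabel("Distance")
ax.set_ylabel("y")
ax.set_title("Two power functions")

# Axis: a Locator decides where the ticks go (here, every 0.5 units)...
ax.xaxis.set_major_locator(MultipleLocator(0.5))


# ...and a Formatter decides what text each tick shows
def with_km(value, pos):
    """Format a tick value with one decimal place and a unit suffix."""
    return f"{value:.1f} km"


fmt = FuncFormatter(with_km)
ax.xaxis.set_major_formatter(fmt)

# Legend built from the line labels
ax.legend(title="Curve")
```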


### Some Examples of Customising Charts

#### Limits and Labels

So far, we've only customised the plot through keyword arguments to `scatter`; let's now see a more realistic (and useful!) example in which we add labels, a title, and more. Note that **matplotlib** supports LaTeX-style maths in text, written between dollar signs.

We'll plot some data from the US Midwest demographics dataset.


In [None]:
df = pd.read_csv(
    "https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2/midwest.csv",
    index_col="PID",
).drop("Unnamed: 0", axis=1)
df.head()

Now, as well as a scatter, let's add some context to the chart:

In [None]:
fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"], edgecolors="k", alpha=0.6)  # Make a scatter plot on "ax"
ax.set_xlim(0, 0.1)  # Set the limits on the x-axis
ax.set_ylim(0, 1e6)  # Set the limits on the y-axis
ax.set_xlabel("Area")  # Set the x label
ax.set_ylabel("Population")  # Set the y label
ax.set_title("Area vs. Population", loc="right");  # Add a title and say where it should go

#### Formatting Axes

Now we have the makings of a basic chart, with axis labels and even a title.

But what we know from the data is that "Area" is actually a percentage. We could represent this by doing

```python
ax.set_xlabel("Area, %")
```

but as **matplotlib** is infinitely customisable, there is another option too---changing the tick labels.

On the x-axis, we'll add a percentage suffix on the numbers plus some minor tick marks.

In [None]:
from matplotlib.ticker import AutoMinorLocator

fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"], edgecolors="k", alpha=0.6)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="right")
ax.xaxis.set_minor_locator(AutoMinorLocator(4))  # Split each major interval into 4, ie 3 minor ticks between major ones
ax.xaxis.set_major_formatter("{x:.2f}%");  # Every x value has 2 decimal places and is followed by a '%' sign

`AutoMinorLocator(4)` divides the space between each pair of major tick marks into 4 equal parts, which means 3 minor tick marks per interval. The major formatter is `'{x:.2f}%'`, which says print 2 decimal places followed by a % sign.
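Passing a plain string to `set_major_formatter` is a shorthand: **matplotlib** wraps the string in a `StrMethodFormatter`, which fills in `{x}` with the tick value. A quick way to check what a format string will produce is to call the formatter directly on a value, as in this sketch with made-up data:

```python
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator, StrMethodFormatter

fig, ax = plt.subplots()
ax.plot([0, 0.05, 0.1], [0, 1, 2])

# A string is converted into a StrMethodFormatter behind the scenes
ax.xaxis.set_major_formatter("{x:.2f}%")
fmt = ax.xaxis.get_major_formatter()

# Calling the formatter on a value shows the tick label it would produce
label = fmt(0.05)

# AutoMinorLocator(4) splits each major interval into 4 parts (3 minor ticks)
ax.xaxis.set_minor_locator(AutoMinorLocator(4))
```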


#### Using Loops to Achieve Customisation Effects

Let's say we want to now differentiate these points with colour according to which state they belong to and add a legend that says which states have which colour. The easiest way to do this is by creating a `for` loop.

In [None]:
fig, ax = plt.subplots()
dark2 = plt.get_cmap("Dark2")  # A qualitative colormap, suited to categorical data
for i, state in enumerate(df["state"].unique()):
    xf = df.loc[df["state"] == state]  # Keep only the rows for this state
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        color=dark2(i),  # Use the i-th colour of the colormap for this state
        label=state,
        s=100,
        edgecolors="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.xaxis.set_minor_locator(AutoMinorLocator(4))
ax.xaxis.set_major_formatter("{x:.2f}%")
ax.ticklabel_format(style="sci", scilimits=(-2, 2), axis="y", useMathText=True)  # Scientific notation on y
ax.legend(title="State", loc="upper right");

Okay, so we managed to get what we wanted. We took 5 qualitatively different colours, one per state, from the 'Dark2' colormap by indexing it with the loop counter `i`; for continuous (as opposed to discrete) variables, there are also sequential colormaps. You can find out more about the colormaps available in base **matplotlib** [here](https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html). We also switched the y-axis tick labels to scientific notation with `ax.ticklabel_format`.

To get one colour per state, we subsetted the dataframe and looped over it state by state. This is because **matplotlib** doesn't always work so well with tidy data: some of its defaults are friendlier when the data are unstacked, eg with one state per column.
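One way to avoid the explicit loop, sketched here with a small made-up dataset standing in for the Midwest data, is to turn the categories into integer codes with `pd.factorize`, pass those codes to `c=` so that `cmap` actually has values to map to colours, and then build the legend with the scatter's `legend_elements` method:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up data standing in for the Midwest dataset
data = pd.DataFrame(
    {
        "area": [0.02, 0.03, 0.05, 0.04, 0.06, 0.01],
        "poptotal": [2e5, 4e5, 1e5, 3e5, 5e5, 6e5],
        "state": ["IL", "IL", "IN", "IN", "MI", "MI"],
    }
)

# An integer code for each row, plus the unique state names in code order
codes, states = pd.factorize(data["state"])

fig, ax = plt.subplots()
sc = ax.scatter(
    data["area"], data["poptotal"], c=codes, cmap="Dark2", s=100, edgecolors="k"
)

# One legend handle per distinct colour; relabel the handles with state names
handles, _ = sc.legend_elements(prop="colors")
ax.legend(handles, list(states), title="State")
```

The trade-off is that the mapping from categories to colours now happens through the colormap's normalisation rather than being chosen explicitly per state.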

There was an alternative approach here, 

## A Tour of **matplotlib** plots

Let's see a couple of other plot types that are useful, beginning with contour plots:

### More Unusual Charts

In [None]:
def f(x, y):
    return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)


x = np.linspace(0, 5, 100)
y = np.linspace(0, 5, 100)

X, Y = np.meshgrid(x, y)
Z = f(X, Y)

fig, ax = plt.subplots()
cf = ax.contourf(X, Y, Z, cmap="plasma")
ax.set_title(r"$f(x,y) = \sin^{10}(x) + \cos(x)\cos\left(10 + y\cdot x\right)$")
cbar = fig.colorbar(cf)

This demonstrates creating a heatmap (here, a filled contour plot) with a colour bar legend and a title rendered using **matplotlib**'s built-in support for LaTeX-style maths. The 'plasma' colormap is perceptually uniform, meaning that equal changes in the data look like equal changes in colour; **matplotlib** has a few of these. If you need more colours, check out the packages [**colorcet**](https://colorcet.holoviz.org/) and [**palettable**](https://jiffyclub.github.io/palettable/).
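If you'd rather see contour lines than (or as well as) filled regions, `ax.contour` draws them and `ax.clabel` writes the level value on each line. This sketch reuses the same function with the 'viridis' colormap, another of **matplotlib**'s perceptually uniform options; the choice of 8 levels is arbitrary:

```python
import matplotlib.pyplot as plt
import numpy as np


def f(x, y):
    return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)


X, Y = np.meshgrid(np.linspace(0, 5, 100), np.linspace(0, 5, 100))
Z = f(X, Y)

fig, ax = plt.subplots()
# Filled contours in the background, plus thin black contour lines on top
ax.contourf(X, Y, Z, levels=8, cmap="viridis")
cs = ax.contour(X, Y, Z, levels=8, colors="k", linewidths=0.5)
labels = ax.clabel(cs, fontsize=7)  # write the level value on the lines
```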

### Multiple Charts on One Axis

You can do some really quite amazing things using **matplotlib**'s build-what-you-want philosophy. For instance, you can [build a timeline](https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/timeline.html#sphx-glr-gallery-lines-bars-and-markers-timeline-py), make [charts with polar axes](https://matplotlib.org/3.1.1/gallery/pie_and_polar_charts/polar_bar.html#sphx-glr-gallery-pie-and-polar-charts-polar-bar-py), and create [XKCD style plots](https://matplotlib.org/3.1.1/gallery/showcase/xkcd.html#sphx-glr-gallery-showcase-xkcd-py).
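To give a flavour of the polar case, here is a minimal sketch with made-up bar lengths. The key step is asking `plt.subplots` for a polar projection via `subplot_kw`; after that, the usual Axes methods work in polar coordinates:

```python
import matplotlib.pyplot as plt
import numpy as np

# subplot_kw passes the projection through to the Axes being created
fig, ax = plt.subplots(subplot_kw={"projection": "polar"})

theta = np.linspace(0, 2 * np.pi, 8, endpoint=False)  # angles of 8 bars
radii = np.array([3, 5, 2, 6, 4, 7, 1, 5])  # made-up bar lengths
ax.bar(theta, radii, width=2 * np.pi / 8, alpha=0.6, edgecolor="k")
```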

One basic bit of functionality that you might need is to put more than one type of information on a single plot. Using the object-oriented API, this is as simple as calling another method on an `ax` that you've already created. In the example below, we'll call `ax.hist` followed by `ax.plot` to overlay the theoretical curve of a standard normal distribution (aka Gaussian) on a histogram of many random draws from that same distribution, generated using **numpy**. Setting `density=True` scales the histogram so that its total area is 1, which puts it on the same scale as the theoretical curve.

In [None]:
rand_draws = prng.standard_normal(5000)
grid_x = np.linspace(-5, 5, 1000)

fig, ax = plt.subplots()
ax.hist(rand_draws, bins=50, density=True)
ax.plot(grid_x, 1 / np.sqrt(2 * np.pi) * np.exp(-(grid_x ** 2) / 2), linewidth=4);
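When the two pieces of information are on very different scales, one option is a second y-axis that shares the same x-axis, created with `ax.twinx()`. A minimal sketch with two made-up series:

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up series on very different scales
years = np.arange(2010, 2020)
output = 100 + 5 * (years - 2010)  # values in the hundreds
rate = 2 + 0.3 * (years - 2010)  # values in single digits

fig, ax = plt.subplots()
ax.plot(years, output, color="tab:blue", linewidth=3)
ax.set_ylabel("Output (index)", color="tab:blue")

# twinx() adds a second Axes sharing this x-axis, with its own y-axis on the right
ax_right = ax.twinx()
ax_right.plot(years, rate, color="tab:orange", linewidth=3)
ax_right.set_ylabel("Rate, %", color="tab:orange")
```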

Another important class of plots that we should look at is line charts, especially time series. Let's grab the real GDP time series for the UK and US and look at their year-on-year growth rates, computed quarter by quarter. The first four entries will be `NaN`, because each growth rate compares a quarter with the same quarter one year (four quarters) earlier.

In [None]:
import pandas_datareader.data as web

ts_start_date = pd.to_datetime("1999-01-01")

df = pd.concat(
    [
        web.DataReader("ticker=RGDP" + x, "econdb", start=ts_start_date)
        for x in ["US", "UK"]
    ],
    axis=1,
)
df.columns = ["US", "UK"]
df.index.name = "Date"
df = 100 * df.pct_change(4)
df.head()
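To see what `100 * df.pct_change(4)` does, here is a small worked sketch on made-up quarterly levels. Each value is compared with the value four rows (one year) earlier, so the first four results have nothing to compare against:

```python
import pandas as pd

# Made-up quarterly levels for six quarters
gdp = pd.Series(
    [100.0, 101.0, 102.0, 103.0, 105.0, 106.05],
    index=pd.period_range("2019Q1", periods=6, freq="Q"),
)

# Percentage change versus the same quarter one year (4 quarters) earlier
growth = 100 * gdp.pct_change(4)
```

For 2020Q1, that is 100 × (105 / 100 − 1) = 5%; for 2020Q2, it is 100 × (106.05 / 101 − 1) = 5% too.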

Okay, now the quick way to plot this would be to simply call `df.plot()`. That works fine and doesn't look too bad: it recognises that it's dealing with a datetime and plots sensible tick labels on the x-axis, and it produces a legend for each of the two time series. But let's say we want to make that plot quickly *and* do some fine-tuning with **matplotlib**: how can we? The answer is to create a `fig` and an `ax` and then to ask our dataframe to use the `ax` we already created to plot the data. We can then carry on using the `ax` object in any way we like.

In [None]:
fig, ax = plt.subplots()
df.plot(ax=ax, lw=3)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right()  # Put tick marks and tick labels on right-hand side

We could also have created this chart by looping through the different countries, just as we looped through the different states in the previous example.

Let's also see what happens when we use different limits:

In [None]:
fig, ax = plt.subplots()
df.plot(ax=ax, lw=3)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right()  # Put tick marks and tick labels on right-hand side
ax.set_xlim(pd.to_datetime("2018-01-01"), pd.to_datetime("2021-01-01"));

As you can see, the plot responded to the shorter time period by showing more detail on the x-axis, here marking quarters. You can specify exactly what you want with tick label formatters designed for datetimes, but the defaults are pretty well behaved.
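As a minimal sketch of the datetime formatters in `matplotlib.dates`, the code below pairs an `AutoDateLocator` with a `ConciseDateFormatter`, which avoids repeating the year on every tick. It uses made-up monthly data and calls `ax.plot` directly rather than `df.plot`, because pandas can apply its own date handling to the x-axis:

```python
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up monthly series covering three years
dates = pd.date_range("2018-01-01", periods=36, freq="MS")
values = np.arange(36)

fig, ax = plt.subplots()
ax.plot(dates, values)

# Choose tick positions automatically, then label them compactly
locator = mdates.AutoDateLocator()
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
```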

Sometimes, you want to show different lines in different panels (or facets, as they are also known). You can do that with **matplotlib** too, though once again it's a build-your-own affair.

Let's put these two time series in their own facets to see how it works.

In [None]:
fig, axes = plt.subplots(1, 2)
for i, ax in enumerate(axes):
    ax.plot(df.index, df.iloc[:, i], lw=3)
    ax.set_title(df.columns[i], loc="left")
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_ylim(df.min().min(), df.max().max())
fig.suptitle("Real GDP growth, %")
plt.tight_layout();

Quite a few things are happening here. First, we asked **matplotlib** for more than one chart using `plt.subplots(nrows, ncols)`; instead of returning a single `ax`, it returned a NumPy array of Axes (here of length 2). We then iterated over those axes using `enumerate`, which gives an integer `i` and an `ax` from the array, and plotted one column on each axis by subsetting the dataframe with `df.iloc[:, i]`. Remember that the first position in `iloc` is for the rows (the index), while the second is for the columns. We also gave both panels the same y-axis limits, taken from the smallest and largest values in the whole dataframe, so that they are directly comparable.

Again, this was quite a lot of code to get a relatively simple chart done! And that's exactly why we're now going to take a look at the declarative library **seaborn** that wraps **matplotlib**.
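Two related options may save some code here. Passing `sharey=True` to `plt.subplots` links the y-axes of all panels, so they always show the same limits; and for grids with more than one row, the returned array is two-dimensional, which `.flat` lets you loop over as if it were one-dimensional. A sketch with made-up data:

```python
import matplotlib.pyplot as plt
import numpy as np

# A 1 x 2 grid returns a 1D NumPy array of Axes; sharey=True links their y-axes
fig, axes = plt.subplots(1, 2, sharey=True)
axes[0].plot([0, 1, 2], [0, 1, 4])
axes[1].plot([0, 1, 2], [0, -1, -2])
axes[0].set_ylim(-5, 5)  # with shared y-axes, this applies to both panels

# A 2 x 2 grid returns a 2D array; .flat iterates over it panel by panel
fig2, grid = plt.subplots(2, 2)
for i, ax in enumerate(grid.flat):
    ax.plot(np.arange(5), (i + 1) * np.arange(5))
    ax.set_title(f"Panel {i + 1}")
fig2.tight_layout()
```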

## Some Useful Charts for Economics

## Advanced Topics

### Creating a style that applies to all charts
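Earlier in this chapter, a style file was applied to every chart with `plt.style.use`. As a minimal sketch of the same idea, a style can also be a dictionary of `rcParams` settings: `plt.style.context` applies it only inside a `with` block, while passing the same dictionary to `plt.style.use` would apply it to every subsequent chart. The settings below are illustrative choices, not the contents of this book's style file; the same `key: value` pairs could instead be saved in a `.mplstyle` file.

```python
import matplotlib.pyplot as plt

# Illustrative settings; any valid rcParams keys can go here
my_style = {
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.titlelocation": "left",
    "lines.linewidth": 2.5,
    "figure.figsize": (6, 4),
}

# The style only applies inside this block; previous settings return afterwards
with plt.style.context(my_style):
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
    ax.set_title("Styled chart")
    linewidth_inside = plt.rcParams["lines.linewidth"]
```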

