(vis-matplotlib)=
# Powerful Data Visualisation with **Matplotlib**


## Introduction

Here you'll see how to make just about any plot you can possibly imagine using the ferociously powerful *imperative* graphing package, **matplotlib**. If you read on to the chapter on {ref}`vis-narrative`, you'll see just how flexible it is. **matplotlib** is the foundation stone of quite a few other data visualisation libraries, such as **plotnine** and **seaborn**, which shows how much it can do. It was famously used to create the first ever image of a black hole {cite:t}`akiyama2019first`.

It's worth saying, though, that **matplotlib** has a very different philosophy from **seaborn** and **plotnine**. Apart from being *imperative* rather than *declarative*, it also prefers wide (unstacked) data to tidy data. With tidy data, every line you want to plot would be identified by the values in a single column called, say, "country"; to create plots quickly with **matplotlib**, it's easier if *each* country has a column of its own.
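To make the difference concrete, here's a minimal sketch using made-up numbers (the country names "A" and "B" and the values are purely illustrative). It uses pandas' `pivot` to turn tidy data into the wide, one-column-per-country shape, after which each column can be plotted as its own line.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up tidy data: one row per (year, country) observation
tidy = pd.DataFrame(
    {
        "year": [2000, 2001, 2002] * 2,
        "country": ["A"] * 3 + ["B"] * 3,
        "value": [1.0, 1.5, 2.0, 3.0, 2.5, 2.0],
    }
)

# Unstack to wide format: one column per country, indexed by year
wide = tidy.pivot(index="year", columns="country", values="value")

fig, ax = plt.subplots()
for country in wide.columns:
    # Each column becomes its own line
    ax.plot(wide.index, wide[country], label=country)
ax.legend();
```

The same wide dataframe could also be drawn in one go with `wide.plot(ax=ax)`, since pandas plots each column as a separate line.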

For a more in-depth introduction to **matplotlib**, head over to [this tutorial](https://github.com/rougier/matplotlib-tutorial). This chapter is indebted to that tutorial, to the excellent [**matplotlib** documentation](https://matplotlib.org/stable/tutorials/index.html), and to the book *Scientific Visualization: Python + Matplotlib* {cite:t}`rougier2021scientific`.

As ever, we'll start by importing some key packages and initialising any settings:

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Set seed for random numbers
seed_for_prng = 78557
prng = np.random.default_rng(seed_for_prng)  # prng = pseudo-random number generator

In [None]:
import matplotlib_inline.backend_inline

# Plot settings
plt.style.use(
    #"https://github.com/aeturrell/coding-for-economists/raw/main/plot_style.txt"
    "plot_style.txt"
)
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")

# Set max rows displayed for readability
pd.set_option("display.max_rows", 6)
pd.set_option("display.max_columns", 7)

## Understanding Matplotlib

You'll see we imported `matplotlib.pyplot as plt` above; this is the main part of **matplotlib** that we'll use in practice.

### The **matplotlib** API

**matplotlib** {cite:t}`hunter2007matplotlib` has its origins in an attempt to replicate the plotting functionality of the commercial programming language MATLAB, but it has since outgrown those roots. Here we'll show you how to use it in the most productive way. It's worth saying, though, that the **matplotlib** API (application programming interface) is huge, so there's no way we'll be able to cover *everything* it can do.

We'll be using **matplotlib**'s 'object-oriented API', which is the recommended way to use it. This just means that we create objects, like *figures* and *axes*, that have state (they remember what you did to them).

The object-oriented API is most often used by creating two fundamental objects that almost every chart relies on: the figure and the axes. You should think of the figure object, `fig`, as the canvas on which you can put any number of charts. Each `ax` (short for 'axes') object is one chart within a figure. Most of the time you're likely to have only one set of axes per figure, but when you need more than one, this setup is really useful. The plotting of elements such as lines, points, bars, and so on is controlled by the `ax` objects, while the overall settings are controlled by `fig`.
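As a quick illustration of how the two objects divide up the work, here's a minimal sketch with two charts on one figure (the data are made up): plotting methods are called on each axes object, while the figure-wide title goes on `fig`.

```python
import matplotlib.pyplot as plt

# One figure (the canvas) holding two axes (two separate charts)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8, 3))

# Plotting happens on each axes object...
ax_left.plot([1, 2, 3], [3, 1, 2])
ax_right.bar(["a", "b", "c"], [2, 5, 3])

# ...while figure-wide settings live on the figure object
fig.suptitle("Two charts, one figure");
```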

The simplest chart we can think of is a line plot. Let's see how to make one:

In [None]:
fig, ax = plt.subplots()  # Get the figure and one axis as a subplot
ax.plot(
    [1, 2, 3, 4, 5, 6],  # Add some data on the x and y axes
    [1, 4, 2, 3, 1, 7],
)


Notice that **matplotlib** will happily accept raw data, such as lists or arrays (but it will work with dataframes too).

```{admonition} Tip
:class: tip
**Matplotlib** returns an object when used in certain contexts, eg it might return `[<matplotlib.lines.Line2D at...` above. To suppress this, end the command with a semi-colon, `;`, or call `plt.show()` as the last command.
```

Let's see an example of a scatter plot using this object-oriented approach. Note that we begin in the same way, with getting a figure and an axis. But now we're going to use `ax.scatter` instead, plus throw in a couple of extra settings (can you guess what they do?).

In [None]:
fig, ax = plt.subplots()        # Create a figure containing a single axes.
ax.scatter(
    [1, 2, 3, 4, 5, 6],  # Plot some data on the axes.
    [1, 4, 2, 3, 1, 7],
    s=150,
    c="b"
    );

`s=150` sets the area of the points (ie the size of each marker), and `c='b'` sets the colour. Many of these features will accept an *array* of values (like `s = [1, 2, 3, 4, 5, 6]`) instead of a single value and will map them into the plot in the way you'd expect. 

Let's see that in practice. For the sizes, we'll use six values spaced evenly between 300 and 2000 (`np.linspace(300, 2000, 6)`). For the colours, we'll just type out six distinct values.

In [None]:
fig, ax = plt.subplots()
ax.scatter(
    [1, 2, 3, 4, 5, 6],
    [1, 4, 2, 3, 1, 7],
    s=np.linspace(300, 2000, 6),
    c=["b", "r", "g", "k", "cyan", "yellow"],
    edgecolors="k",
    alpha=0.5,
);

We also asked for partly transparent points via the `alpha` setting (the default is `alpha=1`, which is a solid colour), and for a black marker-edge colour (`edgecolors='k'`).

As ever, you can call help on a specific command in order to understand what options (keyword arguments) it accepts. And, in Visual Studio Code, you can just hover over the function or method name.

Here is what `help` returns when you run it on the `scatter` method:

In [None]:
help(ax.scatter)

There are a huge number of options for what to put on axes; the table below gives a guide to the most essential ones for 2D plots.

| Code | What it does |
|---|---|
| `Axes.plot` | Plot y versus x as lines and/or markers. |
| `Axes.errorbar` | Plot y versus x as lines and/or markers with attached errorbars. |
| `Axes.scatter` | A scatter plot of y vs x |
| `Axes.step` | Make a step plot. |
| `Axes.loglog` | Make a plot with log scaling on both the x and y axis. |
| `Axes.semilogx` | Make a plot with log scaling on the x axis. |
| `Axes.semilogy` | Make a plot with log scaling on the y axis. |
| `Axes.fill_between` | Fill the area between two horizontal curves. |
| `Axes.fill_betweenx` | Fill the area between two vertical curves. |
| `Axes.bar` | Make a bar plot. |
| `Axes.barh` | Make a horizontal bar plot. |
| `Axes.bar_label` | Label a bar plot. |
| `Axes.stem` | Create a stem plot. |
| `Axes.eventplot` | Plot identical parallel lines at the given positions. |
| `Axes.pie` | Plot a pie chart. |
| `Axes.stackplot` | Draw a stacked area plot. |
| `Axes.broken_barh` | Plot a horizontal sequence of rectangles. |
| `Axes.vlines` | Plot vertical lines at each x from ymin to ymax. |
| `Axes.hlines` | Plot horizontal lines at each y from xmin to xmax. |
| `Axes.axhline` | Add a horizontal line across the Axes. |
| `Axes.axhspan` | Add a horizontal span (rectangle) across the Axes. |
| `Axes.axvline` | Add a vertical line across the Axes. |
| `Axes.axvspan` | Add a vertical span (rectangle) across the Axes. |
| `Axes.axline` | Add an infinitely long straight line. |
| `Axes.fill` | Plot filled polygons. |
| `Axes.boxplot` | Draw a box and whisker plot. |
| `Axes.violinplot` | Make a violin plot. |
| `Axes.violin` | Drawing function for violin plots. |
| `Axes.bxp` | Drawing function for box and whisker plots. |
| `Axes.hexbin` | Make a 2D hexagonal binning plot of points x, y. |
| `Axes.hist` | Compute and plot a histogram. |
| `Axes.hist2d` | Make a 2D histogram plot. |
| `Axes.stairs` | A stepwise constant function as a line with bounding edges or a filled plot. |
| `Axes.contour` | Plot contour lines. |
| `Axes.contourf` | Plot filled contours. |
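As a small sketch of combining a few of these methods on one set of axes (the regions and values are made up, and `Axes.bar_label` needs **matplotlib** 3.4 or later):

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Axes.bar returns a BarContainer holding one rectangle per bar
bars = ax.bar(["North", "South", "East"], [3, 5, 2])
# Axes.bar_label writes each bar's value at the end of the bar
labels = ax.bar_label(bars)
# Axes.axhline draws a horizontal reference line (here, the mean) across the axes
ax.axhline(y=10 / 3, color="k", linestyle="--", linewidth=1);
```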

Now, it might not seem like it, but you already have the makings of a wide range of **matplotlib** plots. You just need to follow this recipe:

1. `fig, ax = plt.subplots()`
2. Choose what you want to put on your axes, for example `ax.scatter` or `ax.plot`
3. Put data into the method you chose in step 2 in the required format (remember you can check the documentation to get the right format: either use `help`, hover over the method name in Visual Studio Code, or head to the **matplotlib** documentation, where there are often also examples to look at; for instance, here are the [examples for hex bins](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.hexbin.html#examples-using-matplotlib-axes-axes-hexbin)).
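Here is that recipe applied to a hex bin plot, as a minimal sketch with simulated, correlated data (the seed simply reuses the one from the start of this chapter, and `gridsize=25` is an arbitrary choice):

```python
import matplotlib.pyplot as plt
import numpy as np

prng = np.random.default_rng(78557)
x = prng.standard_normal(2000)
y = x + prng.standard_normal(2000)  # y is correlated with x

# Step 1: get a figure and an axes
fig, ax = plt.subplots()
# Steps 2 and 3: choose Axes.hexbin and pass it the x and y data
hb = ax.hexbin(x, y, gridsize=25, cmap="viridis")
# hexbin returns a collection whose colour values are the counts per hexagon
fig.colorbar(hb, ax=ax, label="Count");
```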


### Anatomy of a **matplotlib** graph

So, you've made a nice plot and now you want to tweak it. Well, **matplotlib** certainly has a LOT of choice when it comes to tweaking! Rather than go through all of the many, many options for customisation, the figure below (from the documentation) gives an overview of the options:

![Anatomy of a matplotlib figure](https://matplotlib.org/_images/anatomy.png)

Let's run through a few of the most important plot elements:

- **Figure**, or 'fig': the figure keeps track of all the child Axes that are on it and a smattering of 'special' artists (titles, figure legends, etc). A figure can contain any number of Axes, but will typically have at least one if it is to be interesting!

- **Axes**, or 'ax': this is the plot, the region of the image that traces out the data. A given Figure can contain many Axes, but a given Axes object can only be in one Figure. The Axes contains two (or three in the case of 3D) Axis objects (Axes and Axis are different things!) that record the data limits (you can override these via the `axes.Axes.set_xlim()` and `axes.Axes.set_ylim()` methods). Each Axes object has a title (set via `set_title()`), an x-label (set via `set_xlabel()`), and a y-label (set via `set_ylabel()`). When you add, say, a line chart to an Axes object, it is created by calling a method on that Axes object, and it appears as a `Line2D` object associated with that Axes.

- **Axis**: these are the number-line objects that control the limits of what the viewer can see. They also provide the means to access the ticks (the marks on the axis) and ticklabels (strings labeling the ticks). The location of the ticks is determined by a Locator object and the ticklabel strings are formatted by a Formatter. The combination of Locator and Formatter gives very fine control over the tick locations and labels.

- **Markers**: the symbols that mark individual data points, as drawn by scatter plots (line plots can have them too, via the `marker` keyword argument).

- **Labels and titles**: text to help the viewers of the chart make sense of what they're seeing.

- **Legend**: if needed, contains the key to understand the shapes or colours used in lines or markers or bars (or ...) that are on the chart.
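Most of these elements can be reached directly from code. The sketch below, using made-up data, shows where a few of them live:

```python
import matplotlib.pyplot as plt
from matplotlib.axis import XAxis, YAxis

fig, ax = plt.subplots()
ax.plot([0, 5, 10], [1, 3, 2], marker="o")  # a Line2D, with markers at the data points

ax.set_xlim(0, 10)        # limits are stored by the Axes
ax.set_title("A title")   # title, x-label and y-label belong to the Axes
ax.set_xlabel("x label")
ax.set_ylabel("y label")

print(fig.axes)                         # the Figure keeps track of its child Axes
print(type(ax.xaxis), type(ax.yaxis))   # each Axes holds its Axis objects
print(ax.lines)                         # the Line2D artists added by ax.plot
print(ax.xaxis.get_major_locator(), ax.xaxis.get_major_formatter())  # tick Locator and Formatter
```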


## Customising Charts

While the **matplotlib** API is too extensive for us to cover every detail, it's important to know about some standard customisations that you might need to get going.

### Limits and Labels

So far, we've only customised aspects of the plot through the keyword arguments of `scatter`; let's now see a more realistic (and useful!) example in which we'll add labels, a title, and more. Note that **matplotlib** supports LaTeX-style equations in text.
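By default, this uses **matplotlib**'s built-in *mathtext* engine, which understands a subset of TeX written between dollar signs and needs no LaTeX installation (setting the `text.usetex` rcParam to `True` switches to a full LaTeX installation instead). A minimal sketch with made-up data:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])
# Text wrapped in $...$ is parsed by matplotlib's built-in mathtext engine
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y = x^2$")
ax.set_title(r"A quadratic, $f(x) = x^2$")
fig.canvas.draw()  # render once to check that the expressions parse
```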

We'll plot some data from the US Midwest demographics dataset.


In [None]:
df = pd.read_csv(
    "https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2/midwest.csv",
    index_col="PID",
).drop("Unnamed: 0", axis=1)
df.head()

Now, as well as a scatter, let's add some context to the chart:

In [None]:
fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"], edgecolors="k", alpha=0.6)  # Make a scatter plot on "ax"
ax.set_xlim(0, 0.1)  # Set the limits on the x-axis
ax.set_ylim(0, 1e6)  # Set the limits on the y-axis
ax.set_xlabel("Area")  # Set the x label
ax.set_ylabel("Population")  # Set the y label
ax.set_title("Area vs. Population", loc="right");  # Add a title and say where it should go

We're not quite done with titles. Often it's useful to have a y-axis title that is horizontal, and so easier to read. And, additionally, it's good practice to have a title that tells the viewer what they should take away from the graph. To achieve both of these together, you can i) use `plt.suptitle` for a figure-level title whose position can be fine-tuned using `x` and `y` keyword arguments; and ii) use `ax.set_title` to provide the y-axis label. The chart below demonstrates this.

In [None]:
fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"]/1e6, edgecolors="k", alpha=0.6, s=150)
ax.set_ylim(0, None)
ax.set_xlim(0, None)
ax.set_xlabel("Area")
ax.set_title("Population (millions)", loc="left", fontsize=14)
plt.suptitle("Little correlation between population and area", y=1.02, x=0.45);

### Customising Axes

Now we have the makings of a basic chart, with axes labels and even a title. 

But what we know from the data is that "Area" is actually a percentage. We could represent this by doing

```python
ax.set_xlabel("Area, %")
```

but as **matplotlib** is infinitely customisable, there is another option too---changing the tick labels.

On the x-axis, we'll add a percentage suffix on the numbers plus some minor tick marks.

In [None]:
from matplotlib.ticker import AutoMinorLocator

fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"], edgecolors="k", alpha=0.6)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="right")
ax.xaxis.set_minor_locator(AutoMinorLocator(4))  # Add minor tick marks, four between every major one
ax.xaxis.set_major_formatter("{x:.2f}%");  # Every x value has 2 decimal places and is followed by a '%' sign

The `AutoMinorLocator(4)` inserts 4 minor tick marks between each major tick mark. The major formatter is `'{x:.2f}%'`, which says print 2 decimal places followed by a % sign.
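Behind the scenes, `set_major_formatter` wraps a format string in a `StrMethodFormatter`, and (in **matplotlib** 3.3 or later) it also accepts a plain function of the tick value and position, which it wraps in a `FuncFormatter`. Here's a small sketch with made-up values showing both shortcuts; the "m" (millions) formatting of the y-axis is just an illustration:

```python
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, StrMethodFormatter

fig, ax = plt.subplots()
ax.plot([0.01, 0.05, 0.1], [2e5, 6e5, 1e6])

# A format string is wrapped in a StrMethodFormatter for you...
ax.xaxis.set_major_formatter("{x:.2f}%")
# ...and a function of (value, position) is wrapped in a FuncFormatter
ax.yaxis.set_major_formatter(lambda y, pos: f"{y / 1e6:.1f}m")

x_fmt = ax.xaxis.get_major_formatter()
y_fmt = ax.yaxis.get_major_formatter()
print(type(x_fmt).__name__, x_fmt(0.05))  # StrMethodFormatter 0.05%
print(type(y_fmt).__name__, y_fmt(2e6))   # FuncFormatter 2.0m
```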

### Using Loops to Achieve Customisation Effects

Let's say we want to now differentiate these points with colour according to which state they belong to and add a legend that says which states have which colour. The easiest way to do this is by creating a `for` loop.

We're going to loop over `state` using `for state in df["state"].unique():`, which runs over each unique state in turn. The different colours come from the axes' colour cycle: each time `ax.scatter` is called in the loop, it picks the next colour in the cycle. (Passing a `cmap=` keyword argument to `ax.scatter` doesn't change this: a colour map is only applied when `c` is given as an array of numbers.)

Now that colour captures a new dimension of the data (state), we also need a legend that shows which colour corresponds to each state. By passing `label=state` within the `for` loop, we build up state-to-colour pairings that appear on the chart when we call `ax.legend`.

In [None]:
fig, ax = plt.subplots()
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.xaxis.set_minor_locator(AutoMinorLocator(4))
ax.xaxis.set_major_formatter("{x:.2f}%")
ax.legend(title="State", loc="upper right");

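The colours above come from the colour cycle, which gives five qualitatively different colours here. If you want a specific qualitative palette, you can take colours from a qualitative colour map such as "Dark2"; there are also sequential colour maps for continuous (as opposed to discrete) variables. You can find out more about the colour maps available in base **matplotlib** [here](https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html).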
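If you'd rather pick the colours yourself, one option is to take them one at a time from a qualitative colour map. The sketch below does this with "Dark2" on a small made-up stand-in for the county data (the groups and values are illustrative, and `matplotlib.colormaps` needs **matplotlib** 3.5 or later):

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for the county data: three groups of points
toy = pd.DataFrame(
    {
        "group": ["IL", "IL", "IN", "IN", "OH", "OH"],
        "area": [0.02, 0.05, 0.03, 0.04, 0.01, 0.06],
        "poptotal": [2e5, 4e5, 1e5, 3e5, 5e5, 2e5],
    }
)

dark2 = mpl.colormaps["Dark2"]  # a qualitative colour map with 8 distinct colours

fig, ax = plt.subplots()
for i, group in enumerate(toy["group"].unique()):
    sub = toy.loc[toy["group"] == group]
    ax.scatter(
        sub["area"],
        sub["poptotal"],
        color=dark2(i),  # the i-th colour of the map, chosen explicitly
        label=group,
        s=100,
        edgecolor="k",
    )
ax.legend(title="Group");
```

Because the colour is tied to the loop index, the same group gets the same colour however many other charts have been drawn on the axes beforehand.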

### Adding Text Annotations

It's possible to point out specific values with a text label. Adding extra information like this can be useful in all kinds of circumstances; for example, showing the biggest or smallest, drawing attention to a particular story, or simply flagging a special value.

Let's add a couple of text annotations. Let's say we want to annotate the county with the biggest area, and the county with the highest population. For the biggest area, we'll just pop the label next to the point. First we need to *find* the position of the point:

In [None]:
max_area_row = df.iloc[df["area"].argmax()]
max_area_row

Now we use `ax.annotate` to add this information.

In [None]:
fig, ax = plt.subplots()
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
    text=f'Max. area: {max_area_row.loc["county"].title()}',
    xy=tuple(max_area_row[["area", "poptotal"]]),
    );

What if we want to put text somewhere *other* than right next to the datapoint? We can do that and have an arrow to connect the label to the data point.

In [None]:
max_pop_row = df.iloc[df["poptotal"].argmax()]

fig, ax = plt.subplots()
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 6e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
    text=f'Max. area: {max_area_row.loc["county"].title()}',
    xy=tuple(max_area_row[["area", "poptotal"]]),
    xytext=(-100, 20),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3")
    )
ax.annotate(
    text=f'Max. pop: {max_pop_row["county"].title()}',
    xy=tuple(max_pop_row[["area", "poptotal"]]),
    xytext=(-100, -50),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3")
);

You can get really creative with text annotations, as the example below, based on [*Scientific Visualization: Python + Matplotlib*](https://github.com/rougier/scientific-visualization-book) {cite:t}`rougier2021scientific`, shows. You can learn more about text annotations [here](https://matplotlib.org/stable/tutorials/text/annotations.html).

In [None]:
import matplotlib.patheffects as path_effects

fig, axes = plt.subplots(1, 2, figsize=(6, 6))
plt.setp(axes, xlim=[-1, +1], xticks=[], ylim=[-1, +1], yticks=[], aspect=1)

X = prng.normal(0, 0.35, 1000)
Y = prng.normal(0, 0.35, 1000)

I = prng.choice(len(X), size=5, replace=False)
Px, Py = X[I], Y[I]
I = np.argsort(Y[I])[::-1]
Px, Py = Px[I], Py[I]

for ax in axes:
    ax.scatter(X, Y, edgecolor="None", facecolor="C1", alpha=0.5)
    ax.scatter(Px, Py, edgecolor="black", facecolor="white", zorder=20)
    ax.scatter(Px, Py, edgecolor="None", facecolor="C1", alpha=0.5, zorder=30)

# first subplot
y, dy = 1.0, 0.125
style = "arc,angleA=-0,angleB=0,armA=-100,armB=0,rad=0"

for i in range(len(I)):
    text = axes[0].annotate(
        "Point " + chr(ord("A") + i),
        xy=(Px[i], Py[i]),
        xycoords="data",
        xytext=(0, 30),
        textcoords="offset points",
        ha="center",
        size="small",
        arrowprops=dict(
            arrowstyle="->", shrinkA=0, shrinkB=5, color="black", linewidth=0.75
        ),
    )
    text.set_path_effects(
        [path_effects.Stroke(linewidth=2, foreground="white"), path_effects.Normal()]
    )
    text.arrow_patch.set_path_effects(
        [path_effects.Stroke(linewidth=2, foreground="white"), path_effects.Normal()]
    )

# 2nd subplot
for i in range(len(I)):
    text = axes[1].annotate(
        "Point " + chr(ord("A") + i),
        xy=(Px[i], Py[i]),
        xycoords="data",
        xytext=(1.25, y - i * dy),
        textcoords="data",
        arrowprops=dict(
            arrowstyle="->",
            color="black",
            linewidth=0.75,
            shrinkA=20,
            shrinkB=5,
            patchA=None,
            patchB=None,
            connectionstyle=style,
        ),
    )
    text.arrow_patch.set_path_effects(
        [path_effects.Stroke(linewidth=2, foreground="white"), path_effects.Normal()]
    )
plt.tight_layout();

### Special Lines

Next, we're going to add some special lines to our chart. This is surprisingly useful in practice, and there are a few commands. If we just want a horizontal or vertical line, our best bet is `ax.axhline` and `ax.axvline` respectively.
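Besides `ax.axhline` and `ax.axvline`, shaded bands and sloped lines work in a similar way. A minimal sketch on empty axes (the positions are arbitrary, and `Axes.axline` needs **matplotlib** 3.3 or later):

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)

ax.axhline(y=5, color="k", linewidth=0.8)                # horizontal line across the axes
ax.axvspan(2, 4, color="grey", alpha=0.3)                # shaded vertical band from x=2 to x=4
ax.axline((0, 0), slope=1, color="C1", linestyle="--");  # infinite 45-degree line through (0, 0)
```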

Let's add a special line to our example to show where the *mean area* of counties appears. First we need the mean area:

In [None]:
mean_county_area = df["area"].mean()
mean_county_area

Now we're going to add the line, using `ax.axvline`, but also a corresponding annotation that tells viewers what this line is showing. 

In [None]:

fig, ax = plt.subplots()
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 6e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
    text=f'Max. area: {max_area_row.loc["county"].title()}',
    xy=tuple(max_area_row[["area", "poptotal"]]),
    xytext=(-100, 20),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3")
    )
ax.annotate(
    text=f'Max. pop: {max_pop_row["county"].title()}',
    xy=tuple(max_pop_row[["area", "poptotal"]]),
    xytext=(20, -50),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3")
)
ax.axvline(x=mean_county_area, linewidth=0.5, linestyle='-.')
ax.annotate("Mean county area",
            xy=(mean_county_area, 0.5),
            xycoords=("data", "axes fraction"),
            rotation=-90,
            fontsize=11);

Note that the second co-ordinate of the text annotation is given as a fraction of the axes' height ("axes fraction") rather than in the co-ordinates of the data. This is useful because, whatever else changes in the chart, we know the text will appear half-way up the axes. There are several different co-ordinate systems that **matplotlib** accepts, depending on what you're trying to achieve:

- the co-ordinates of the *data*, eg area and population in the chart above
- the fraction of the axes, aka *"axes fraction"* (and, similarly, the fraction of the whole figure, aka *"figure fraction"*)
- *offset points*, relative to another point (used for text relative to a data point)
- various options using *pixels*
- and more, including *polar* co-ordinates

You can find out a bit more about the different co-ordinate systems on the [**matplotlib** documentation](https://matplotlib.org/stable/tutorials/advanced/transforms_tutorial.html).
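As a small sketch of the most common systems (the text positions are arbitrary), the chart below places text in data co-ordinates, in axes fraction via `transform=ax.transAxes`, and in figure fraction via `xycoords`. The last lines check that the middle of the axes lands on the same pixel whichever of the first two systems you use.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(0, 100)

# Data co-ordinates (the default for ax.text)
ax.text(2, 80, "at data point (2, 80)")
# Axes fraction: (0, 0) is the bottom-left of the axes, (1, 1) the top-right
txt = ax.text(0.98, 0.02, "bottom-right of axes", transform=ax.transAxes, ha="right")
# Figure fraction, via annotate's xycoords
ann = ax.annotate("top-left of figure", xy=(0.01, 0.97), xycoords="figure fraction", va="top")

# The centre of the axes maps to the same pixel in both systems
centre_in_pixels_data = ax.transData.transform((5, 50))
centre_in_pixels_axes = ax.transAxes.transform((0.5, 0.5))
print(np.allclose(centre_in_pixels_data, centre_in_pixels_axes))  # True
```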

## Multiple Charts on one Figure

One basic bit of functionality that you might need is to put more than one type of information on a single overall figure. This could mean having multiple subplots or it could be about putting more information on a single set of axes.

### Multiple Features on One Set of Axes

This is the kind of task where **matplotlib**'s build-what-you-want philosophy starts to win out. To add another feature on an `ax` that you've already created is as simple as calling `ax.<method>` again. You can add as many features as you like.

In the example below, we'll call `ax.hist` followed by `ax.plot` to get the theoretical curve for a normal distribution (aka Gaussian) overlaid on a histogram, normalised to be a density, of many draws from that distribution generated using **numpy**.

While we're at it, let's add an equation that describes the theoretical curve.

In [None]:
rand_draws = prng.standard_normal(5000)
grid_x = np.linspace(-5, 5, 1000)

fig, ax = plt.subplots()
ax.hist(rand_draws, bins=50, density=True, label="Data")
ax.plot(
    grid_x,
    1 / np.sqrt(2 * np.pi) * np.exp(-(grid_x ** 2) / 2),
    linewidth=4,
    label=r"$\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}$"
    )
ax.legend(fontsize=14);
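Because of `density=True`, the bar heights are rescaled so that the total area of the histogram is one, which is what makes it comparable with a probability density. A quick sketch, reusing the same kind of simulated draws, that checks this with the values `ax.hist` returns:

```python
import matplotlib.pyplot as plt
import numpy as np

prng = np.random.default_rng(78557)
rand_draws = prng.standard_normal(5000)

fig, ax = plt.subplots()
# ax.hist returns the bar heights, the bin edges, and the drawn patches
heights, edges, patches = ax.hist(rand_draws, bins=50, density=True)

# With density=True, (height x bin width) summed over all bins is one
area = np.sum(heights * np.diff(edges))
print(round(area, 6))
```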

### Facets

Another way to put more information on a single figure is to have different *facets*. To illustrate some of the ideas we're about to see, it's going to be useful to have some data. Let's pull down GDP per capita for a selection of countries.

In [None]:
from pandas_datareader import wb

start_year = 2000
end_year = 2022

df = wb.download(
    indicator=["NY.GDP.PCAP.KD"],                  # GDP per capita in 2015 USD
    country=["US", "GB", "FR", "DE", "IT", "JP"],
    start=start_year,
    end=end_year,
)
df = df.reset_index()                              # drop country as index (for plots)
df["year"] = df["year"].astype("int")              # ensure year is a number
# index to that country's value in the start year (= 100)
base_values = df.loc[df["year"] == start_year].set_index("country")["NY.GDP.PCAP.KD"]
df["GDP per capita"] = 100 * df["NY.GDP.PCAP.KD"] / df["country"].map(base_values)
df.head()

#### Subplots

The first way to put data on multiple charts that are part of the same overall figure is to use the built-in subplots function. `plt.subplots()` accepts `nrows=` and `ncols=` arguments that set how many rows and columns of charts there are. We have six countries, so let's use two rows of three columns.

We can use a `for` loop to go over the subplots. However, the `axes` object that comes back from `plt.subplots` is a 2x3 array, and looping over it directly would give you a whole row of axes at each step rather than a single chart. Instead, we can loop over a "flattened" version of it, which you can create with `axes.flatten()`.

Once we're in the loop, we subset the data by country for each loop and add it to the chart using `ax.plot`. Finally, we'll use the same limits for every subplot to make the chart easier to read.
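To see the difference between looping over `axes` directly and looping over `axes.flatten()`, here's a small check on an empty 2x3 grid:

```python
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)
print(axes.shape)  # (2, 3): a 2D NumPy array of Axes

for row in axes:  # looping directly gives one *row* of Axes at a time
    print(type(row).__name__, len(row))

for i, ax in enumerate(axes.flatten()):  # flattening gives the six Axes in turn
    ax.set_title(f"subplot {i}")
```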

Let's see this with the data we just pulled down.

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for i, ax in enumerate(axes.flatten()):
    country = df["country"].unique()[i]
    country_df = df.loc[df["country"] == country, :]
    ax.plot(country_df["year"], country_df["GDP per capita"],
            lw=3, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][i])
    ax.set_title(country, loc="center", fontsize=13)
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(df["GDP per capita"].min(), df["GDP per capita"].max())
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();

A nice extra you can add to this sort of chart is the other countries' lines in grey, so that it's clear which country is featured while cross-country comparisons are easier to make.

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for i, ax in enumerate(axes.flatten()):
    country = df["country"].unique()[i]
    # grab the other countries
    other_countries = [x for x in df["country"].unique() if x!=country]
    for other in other_countries:
        o_country_df = df.loc[df["country"] == other, :]
        ax.plot(o_country_df["year"], o_country_df["GDP per capita"],
            lw=1, color="k", alpha=0.1)
    country_df = df.loc[df["country"] == country, :]
    ax.plot(country_df["year"], country_df["GDP per capita"],
            lw=3, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][i])
    ax.set_title(country, loc="center", fontsize=13)
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(df["GDP per capita"].min(), df["GDP per capita"].max())
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();

#### Gridspec

This is a more complicated way of creating multiple subplots, but it is more flexible too: although it still relies on an underlying grid, a single subplot can span several cells of that grid. It's easier to see than to describe!
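At its simplest, you create a grid with `fig.add_gridspec` and then slice it to say which cells each subplot should cover. A minimal sketch with an arbitrary 2x3 layout:

```python
import matplotlib.pyplot as plt

fig = plt.figure(constrained_layout=True)
gs = fig.add_gridspec(nrows=2, ncols=3)  # a 2 x 3 grid of cells

ax_tall = fig.add_subplot(gs[:, 0])    # spans both rows of the first column
ax_wide = fig.add_subplot(gs[0, 1:])   # spans the last two columns of the top row
ax_small = fig.add_subplot(gs[1, 1])   # a single cell
ax_small2 = fig.add_subplot(gs[1, 2])  # another single cell

# Label each subplot in the middle of its axes
for ax, name in zip(fig.axes, ["tall", "wide", "small", "small"]):
    ax.text(0.5, 0.5, name, ha="center", va="center", transform=ax.transAxes)
```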

Here's an example that shows what's possible:

In [None]:
import matplotlib.gridspec as gridspec

p = plt.rcParams
p["figure.figsize"] = 10, 7
p["font.sans-serif"] = ["Roboto Condensed"]
p["font.weight"] = "light"
p["ytick.minor.visible"] = True
p["xtick.minor.visible"] = True
p["axes.grid"] = True
p["grid.color"] = "0.5"
p["grid.linewidth"] = 0.2


X = np.linspace(-np.pi, np.pi, 257, endpoint=True)
C, S = np.cos(X), np.sin(X)

fig = plt.figure(constrained_layout=True)
nrows, ncols = 3, 3
gspec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)


def plot(ax, text):
    ax.set_xlim(0, 1)
    ax.set_xticks(np.linspace(0, 1, 5))
    ax.set_xlabel("X Label", fontsize=10)
    ax.set_ylim(0, 1)
    ax.set_yticks(np.linspace(0, 1, 5))
    ax.set_ylabel("Y Label", fontsize=10)
    ax.text(
        0.5, 0.5, text, alpha=0.75, ha="center", va="center", weight="bold", size=12
    )
    ax.set_title("Title", family="Roboto", weight=500, fontsize=12)


for i in range(ncols):
    plot(plt.subplot(gspec[0, i]), "subplot(gspec[0,%d])" % i)
for i in range(1, nrows):
    plot(plt.subplot(gspec[i, 0]), "subplot(gspec[%d,0])" % i)
plot(plt.subplot(gspec[1:, 1:]), "subplot(gspec[1:,1:])");

Let's say we wanted to really feature a single country more than the others: we might want to give it a full height spot. We can do this with `gridspec`.

Taking our example, let's "focus" on the UK.


In [None]:
import matplotlib.gridspec as gridspec

focus = "United Kingdom"  # the country to feature
colours = plt.rcParams["axes.prop_cycle"].by_key()["color"]
other_countries = [c for c in df["country"].unique() if c != focus]

fig = plt.figure(figsize=(10, 8), constrained_layout=True)
# Two columns: a wide one for the featured country, a narrower one for the rest
gspec = gridspec.GridSpec(
    nrows=len(other_countries), ncols=2, figure=fig, width_ratios=[2, 1]
)

# The featured country takes up the full height of the first column
ax_focus = fig.add_subplot(gspec[:, 0])
focus_df = df.loc[df["country"] == focus, :]
ax_focus.plot(focus_df["year"], focus_df["GDP per capita"], lw=3, color=colours[0])
ax_focus.set_title(focus, loc="center", fontsize=15)

# Every other country gets one row of the second column
for i, other in enumerate(other_countries):
    ax = fig.add_subplot(gspec[i, 1])
    o_country_df = df.loc[df["country"] == other, :]
    ax.plot(o_country_df["year"], o_country_df["GDP per capita"],
            lw=1.5, color=colours[i + 1])
    ax.set_title(other, loc="center", fontsize=11)

# Common styling for all of the subplots
for ax in fig.axes:
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(df["GDP per capita"].min(), df["GDP per capita"].max())
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15);


## A Tour of **matplotlib** plots

Let's see a couple of other plot types that are useful, beginning with contour plots:

### More Unusual Charts

In [None]:
def f(x, y):
    return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)


x = np.linspace(0, 5, 100)
y = np.linspace(0, 5, 100)

X, Y = np.meshgrid(x, y)
Z = f(X, Y)

fig, ax = plt.subplots()
cf = ax.contourf(X, Y, Z, cmap="plasma")
ax.set_title(r"$f(x,y) = \sin^{10}(x) + \cos(x)\cos\left(10 + y\cdot x\right)$")
cbar = fig.colorbar(cf)

This demonstrates creating a filled contour plot with a colour bar legend and a title that's rendered using LaTeX-style mathtext. The colour map used, "plasma", is perceptually uniform, meaning that equal changes in the data look like equal changes in colour; **matplotlib** has a few of these. If you need more colour maps, check out the packages [**colorcet**](https://colorcet.holoviz.org/) and [**palettable**](https://jiffyclub.github.io/palettable/).

### Other Chart Types

You can do some really quite amazing things using **matplotlib**'s build-what-you-want philosophy. For instance, you can [build a timeline](https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/timeline.html#sphx-glr-gallery-lines-bars-and-markers-timeline-py), make [charts with polar axes](https://matplotlib.org/3.1.1/gallery/pie_and_polar_charts/polar_bar.html#sphx-glr-gallery-pie-and-polar-charts-polar-bar-py), and create [XKCD-style plots](https://matplotlib.org/3.1.1/gallery/showcase/xkcd.html#sphx-glr-gallery-showcase-xkcd-py).

One basic bit of functionality that you might need is to put more than one type of information on a single plot. Using the object-oriented API, this is as simple as calling another method on an `ax` that you've already created. In the example below, we'll call `ax.hist` followed by `ax.plot` to overlay the theoretical curve for a standard normal (aka Gaussian) distribution on a density-normalised histogram of 5,000 draws from that distribution, generated with **numpy**.

In [None]:
rand_draws = prng.standard_normal(5000)
grid_x = np.linspace(-5, 5, 1000)

fig, ax = plt.subplots()
ax.hist(rand_draws, bins=50, density=True)
ax.plot(grid_x, 1 / np.sqrt(2 * np.pi) * np.exp(-(grid_x ** 2) / 2), linewidth=4);
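The example above overlays the theoretical curve on a histogram. If you'd also like a kernel density estimate (KDE), here's a minimal sketch that computes one directly with **numpy**: the estimate at each grid point is $\hat{f}(x) = \frac{1}{nh}\sum_{i=1}^{n} \phi\left(\frac{x - x_i}{h}\right)$, an average of Gaussian "bumps" $\phi$ centred on each draw $x_i$, with the bandwidth $h$ set by the normal reference rule of thumb $h = 1.06\,\hat{\sigma}\,n^{-1/5}$. The function `gaussian_kde_numpy` is a hypothetical helper written for illustration; in practice you might reach for a library implementation such as `scipy.stats.gaussian_kde`.

```python
import numpy as np
import matplotlib.pyplot as plt

prng = np.random.default_rng(78557)
rand_draws = prng.standard_normal(5000)
grid_x = np.linspace(-5, 5, 1000)


def gaussian_kde_numpy(samples, grid, bandwidth):
    """Average of Gaussian bumps of width `bandwidth` centred on each sample."""
    z = (grid[:, None] - samples[None, :]) / bandwidth
    return np.exp(-0.5 * z**2).sum(axis=1) / (
        len(samples) * bandwidth * np.sqrt(2 * np.pi)
    )


# Normal reference rule-of-thumb bandwidth
bandwidth = 1.06 * rand_draws.std() * len(rand_draws) ** (-1 / 5)
kde = gaussian_kde_numpy(rand_draws, grid_x, bandwidth)
theory = np.exp(-(grid_x**2) / 2) / np.sqrt(2 * np.pi)

fig, ax = plt.subplots()
ax.hist(rand_draws, bins=50, density=True, alpha=0.4, label="Histogram")
ax.plot(grid_x, kde, lw=2, label="KDE")
ax.plot(grid_x, theory, lw=2, ls="--", label="Theoretical N(0, 1)")
ax.legend()
```

Because each element has a `label`, a single call to `ax.legend()` picks up all three layers.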

Another important class of plots is line charts, especially for time series. Let's grab the quarterly real GDP time series for the UK and US and compute year-on-year growth rates (the first four entries will be `NaN` because each year-on-year growth rate needs the value from four quarters earlier).

In [None]:
import pandas_datareader.data as web

ts_start_date = pd.to_datetime("1999-01-01")

df = pd.concat(
    [
        web.DataReader("ticker=RGDP" + x, "econdb", start=ts_start_date)
        for x in ["US", "UK"]
    ],
    axis=1,
)
df.columns = ["US", "UK"]
df.index.name = "Date"
# Year-on-year growth: compare each quarter with the same quarter 4 quarters earlier
df = 100 * df.pct_change(4)
df.head()
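If the econdb download isn't available to you, the sketch below shows what `pct_change(4)` does using some made-up quarterly GDP levels: year-on-year growth in quarter $t$ is $100 \times (y_t / y_{t-4} - 1)$, which is the same as dividing the series by a copy of itself shifted back four quarters.

```python
import numpy as np
import pandas as pd

# Made-up quarterly real GDP levels, indexed by quarter start dates
dates = pd.date_range("2019-01-01", periods=8, freq="QS")
levels = pd.DataFrame(
    {
        "US": [100, 101, 102, 103, 104, 105, 106, 107],
        "UK": [200, 200, 202, 204, 206, 208, 210, 212],
    },
    index=dates,
)

# pandas' built-in year-on-year percentage change...
yoy_growth = 100 * levels.pct_change(4)
# ...and the same calculation written out by hand
manual = 100 * (levels / levels.shift(4) - 1)
print(yoy_growth)
```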

Okay, now the quick way to plot this would be to simply call `df.plot()`. That works fine and doesn't look too bad: it recognises that the index contains dates and puts sensible tick labels on the x-axis, and it produces a legend entry for each of the two time series. But let's say we want to make that plot quickly *and* do some fine-tuning with **matplotlib**: how can we? The answer is to create a `fig` and an `ax` and then pass the `ax` we already created to the dataframe's plot method, as in `df.plot(ax=ax)`. We can then carry on using the `ax` object in any way we like.

In [None]:
fig, ax = plt.subplots()
df.plot(ax=ax, lw=3)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right()  # Put tick marks and tick labels on right-hand side

We could also have created this chart by looping through the different countries, just as we looped through the different states in the previous example.
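Here's a minimal sketch of that loop-based approach. Because the econdb data may not be to hand, it uses made-up quarterly growth rates in a dataframe called `growth` (a hypothetical stand-in for `df`). Each country gets its own `ax.plot` call with a `label` so that `ax.legend` can find it, and a zero line drawn with `ax.axhline` shows the kind of extra fine-tuning you can layer on top.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Made-up quarterly growth rates standing in for the downloaded data
prng = np.random.default_rng(78557)
dates = pd.date_range("2015-01-01", periods=24, freq="QS")
growth = pd.DataFrame(
    {"US": 2 + prng.standard_normal(24), "UK": 1.5 + prng.standard_normal(24)},
    index=dates,
)

fig, ax = plt.subplots()
# One ax.plot call per country (column), labelled for the legend
for country in growth.columns:
    ax.plot(growth.index, growth[country], lw=3, label=country)
ax.axhline(0, color="black", lw=1)  # unlabelled reference line at zero growth
ax.legend()
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right()
```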

Let's also see what happens when we narrow the x-axis limits:

In [None]:
fig, ax = plt.subplots()
df.plot(ax=ax, lw=3)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right()  # Put tick marks and tick labels on right-hand side
ax.set_xlim(pd.to_datetime("2018-01-01"), pd.to_datetime("2021-01-01"));

As you can see, the x-axis ticks adapted to the shorter time period by showing more detail, here quarters. You can specify exactly which ticks appear and how they're labelled using the date locators and formatters in `matplotlib.dates`, but the defaults are pretty well-behaved.
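Here's a minimal sketch of taking control of date ticks with `matplotlib.dates`, using made-up quarterly values: a `MonthLocator` puts a tick at the start of every quarter and a `DateFormatter` labels each one as year-month. It plots with `ax.plot` rather than `df.plot` because **pandas** can set up its own date handling on the x-axis for regularly spaced series, which may not play nicely with these locators.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# Made-up quarterly values
dates = pd.date_range("2018-01-01", periods=13, freq="QS")
prng = np.random.default_rng(78557)
values = 2 + prng.standard_normal(len(dates))

fig, ax = plt.subplots()
ax.plot(dates, values, lw=3)
# Tick at the start of each quarter (January, April, July, October)
ax.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1, 4, 7, 10]))
# Label each tick as year-month, e.g. 2019-04
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
ax.set_xlim(pd.to_datetime("2018-01-01"), pd.to_datetime("2021-01-01"))
```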

Sometimes, you want to show different lines in different panels (or facets, as they are also known). You can do that with **matplotlib** too, though once again it's a build-your-own affair.

Let's put these two time series in their own facets to see how it works.

In [None]:
fig, axes = plt.subplots(1, 2)
for i, ax in enumerate(axes):
    ax.plot(df.index, df.iloc[:, i], lw=3)
    ax.set_title(df.columns[i], loc="left")
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_ylim(df.min().min(), df.max().max())
fig.suptitle("Real GDP growth, %")
plt.tight_layout();

Quite a few things are happening here. The first is that we asked **matplotlib** for more than one panel using `plt.subplots(nrows, ncols)`; instead of returning a single `ax`, it returned a NumPy array of axes (here of length 2). We then iterated over those axes using `enumerate`, which gives an integer `i` and an `ax` from the array, and plotted a column on each axis by subsetting the dataframe one column at a time using `df.iloc[:, i]`. Remember that the first position in `iloc` selects rows (by position), while the second selects columns. We also set each panel's y-limits to the overall minimum and maximum of the data so that the two panels share the same scale. Again, this was quite a lot of code to get a relatively simple chart done! That's exactly why declarative libraries that wrap **matplotlib**, such as **seaborn**, can be so convenient.
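When you ask for a grid with more than one row *and* more than one column, `plt.subplots` returns a 2-D NumPy array of axes, so a plain `for ax in axes` loop would hand you whole rows rather than individual axes. Here's a minimal sketch, using made-up straight lines as data, of one way to handle this: iterate over `axes.flat`. It also uses `sharey=True` as an alternative to setting the y-limits by hand.

```python
import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2, sharey=True)
# axes is a 2-D NumPy array of Axes; .flat walks through it row by row
for i, ax in enumerate(axes.flat):
    ax.plot(np.arange(10), (i + 1) * np.arange(10))  # made-up lines
    ax.set_title(f"Panel {i}", loc="left")
```

With the default of one row and one column, `plt.subplots` returns a single `ax` rather than an array, because its `squeeze` argument defaults to `True`.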

## Some Useful Charts for Economics

## Advanced Topics

### Creating a style that applies to all charts



## Fun Charts