# Exercise 1.4 - Multiple plots per figure
prepared by M.Hauser

Until now we used only one axes per figure. However, you may want to place several subplots in the same figure.
As mentioned earlier, matplotlib allows more than one axes per figure.

In [None]:
import matplotlib.pyplot as plt

## `plt.subplots` - simple grids

Until now, we used `plt.subplots()` (notice the `s` at the end) to create a single axes, but it supports the `nrows` and `ncols` keywords to create regular grids.



In [None]:
f, axs = plt.subplots(2, 3)
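
The two positional arguments are `nrows` and `ncols`. Writing them out as keywords can make the call easier to read. A minimal sketch (the `figsize` value is an arbitrary choice for illustration):

```python
import matplotlib.pyplot as plt

# same grid as above, with the keywords written out
f, axs = plt.subplots(nrows=2, ncols=3, figsize=(8, 4))

# the figure keeps a list of all its axes, one per grid cell
print(len(f.axes))
```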

## `hspace` and `wspace`

This looks ok, but there is not enough space between the axes for the labels. You can adjust the space between the axes manually with `hspace` and `wspace` (`h` and `w` stand for height and width, respectively). Note that `hspace` and `wspace` are properties of the GridSpec (discussed later), so we have to pass them via `gridspec_kw`.

 * `wspace`: the amount of width reserved for space between subplots, expressed as a fraction of the average axes width, default value = 0.2
 * `hspace`: the amount of height reserved for space between subplots, expressed as a fraction of the average axes height, default value = 0.2

(I remember this as "w(idth)space" and "h(eight)space".)

In [None]:
f, axs = plt.subplots(2, 3, gridspec_kw=dict(hspace=0.3, wspace=0.5))

Note that `hspace` and `wspace` are relative measures. With four rows instead of two, the same `hspace` produces a smaller absolute gap:

In [None]:
f, axs = plt.subplots(4, 3, gridspec_kw=dict(hspace=0.3, wspace=0.5))
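
To make the relative nature of `hspace` concrete, the following sketch measures the vertical gap between the first two rows in figure coordinates. The helper `vertical_gap` is a hypothetical name introduced only for this illustration; it relies on `ax.get_position()`, which returns the bounding box of an axes.

```python
import matplotlib.pyplot as plt


def vertical_gap(nrows, hspace=0.3):
    """Return (gap, axes height) of the first column, in figure coordinates."""
    f, axs = plt.subplots(nrows, 3, gridspec_kw=dict(hspace=hspace, wspace=0.5))
    upper = axs[0, 0].get_position()
    lower = axs[1, 0].get_position()
    # distance between the bottom of the upper axes and the top of the lower one
    gap = upper.y0 - lower.y1
    plt.close(f)
    return gap, upper.height


for nrows in (2, 4):
    gap, height = vertical_gap(nrows)
    print(f"{nrows} rows: gap = {gap:.3f}, height = {height:.3f}, ratio = {gap / height:.2f}")
```

The ratio of gap to axes height should equal `hspace` in both cases, while the absolute gap shrinks as the number of rows grows.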

Alternatively, you can tell matplotlib to adjust the distances automatically with `layout="constrained"`. However, this may not always work (or may not work as you wish).

In [None]:
f, axs = plt.subplots(2, 3, layout="constrained")

The axes can also share the x and y axis. This automatically removes the inner x and y tick labels.

Note that when sharing the axis, setting the x or y limit for one axes adjusts them for all the axes.

In [None]:
f, axs = plt.subplots(2, 3, sharex=True, sharey=True)

ax = axs[0, 0]

ax.set_xlim(-0.5, 0.5)

Note the x-axis: the limits changed for all axes, not only for `axs[0, 0]`.
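
The same behaviour can be checked without looking at the plot. A small sketch that sets the limits on one axes and then queries every axes of the figure through `f.axes`:

```python
import matplotlib.pyplot as plt

f, axs = plt.subplots(2, 3, sharex=True, sharey=True)

# set the limits on one axes only
axs[0, 0].set_xlim(-0.5, 0.5)

# every axes of the figure reports the same x limits
for ax in f.axes:
    print(ax.get_xlim())
```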



If you have more than one column and more than one row of axes, `axs` is a 2D array. We use the abbreviation `axs` instead of `axes` to differentiate single Axes instances from a collection of Axes instances.

In [None]:
print(axs.shape)

axs
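
What `plt.subplots` returns depends on the grid: a single Axes for a 1 x 1 grid, a 1D array for a single row or column, and a 2D array otherwise. The `squeeze=False` keyword forces a 2D array in all cases, which can help when the number of rows or columns is not known in advance. A short sketch:

```python
import matplotlib.pyplot as plt

f, ax = plt.subplots()  # a single Axes, not an array
f, axs_row = plt.subplots(1, 3)  # 1D array of length 3
f, axs_2d = plt.subplots(1, 3, squeeze=False)  # always a 2D array

print(isinstance(ax, plt.Axes))
print(axs_row.shape)
print(axs_2d.shape)
```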

Let's see the order of the axes in the array:

In [None]:
f, axs = plt.subplots(2, 3, sharex=True, sharey=True)

# axes are in a two-dimensional array, indexed by [row, col]
for i in range(2):
    for j in range(3):
        axs[i, j].text(0.5, 0.5, f"({i}, {j})", fontsize=16, ha="center", va="center")

Because I don't find it very convenient to address the axes as `axs[i, j]`, I often `flatten` them to a 1D array.

In [None]:
f, axs = plt.subplots(2, 3, sharex=True, sharey=True)

axs = axs.flatten()

# axs are now in a one-dimensional array, indexed by [index]
for i, ax in enumerate(axs):
    ax.text(0.5, 0.5, f"({i})", fontsize=16, ha="center", va="center")
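
If you only need to loop over the axes, `axs.flat` gives an iterator in the same row-by-row order without creating a new array. Combined with `zip`, it pairs each axes with an item of another sequence. A sketch with hypothetical panel labels:

```python
import matplotlib.pyplot as plt

f, axs = plt.subplots(2, 3, sharex=True, sharey=True)

# hypothetical panel labels, one per axes
labels = ["a", "b", "c", "d", "e", "f"]

# axs.flat iterates row by row, in the same order as axs.flatten()
for ax, label in zip(axs.flat, labels):
    ax.set_title(f"({label})", loc="left")
```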

## Adding axes by hand (`plt.axes`)

Axes can also be added by `plt.axes`. 

```python
ax = plt.axes()
```

Using `ax = plt.axes()` creates the same axes object as `f, ax = plt.subplots(1, 1)`, but it does not return the figure handle (`f`). 

```python
ax = plt.axes(rect)
```

With `ax = plt.axes(rect)` you can also specify the position of the new axes. The argument `rect` needs to have the form `rect = [left, bottom, width, height]`. The positions are given in the figure coordinate system, which ranges from 0 at the bottom left of the figure to 1 at the top right of the figure.

For example, to create an inset in the top right corner, we could do the following.

In [None]:
ax1 = plt.axes()
ax2 = plt.axes([0.6, 0.6, 0.27, 0.25])
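
To see how `rect` maps to the figure coordinate system, you can read the position back with `ax.get_position()`. A minimal sketch using the same numbers as the inset above:

```python
import matplotlib.pyplot as plt

f = plt.figure()

rect = [0.6, 0.6, 0.27, 0.25]  # [left, bottom, width, height]
ax = plt.axes(rect)

pos = ax.get_position()
print(pos.x0, pos.y0)  # lower left corner: left, bottom
print(pos.x1, pos.y1)  # upper right corner: left + width, bottom + height
```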

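As mentioned in the presentation: `ax1` and `ax2` are instances of `plt.Axes`, but only `ax1` is a `plt.Subplot`. (In recent matplotlib versions, 3.7 and later as far as I know, `Subplot` is merely an alias of `Axes`, so both checks may return `True`.)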

In [None]:
for ax in [ax1, ax2]:
    print(ax)
    print(f" * {isinstance(ax, plt.Axes)    = }")
    print(f" * {isinstance(ax, plt.Subplot) = }")
    print()

## Arbitrary grids (`plt.GridSpec`)

If you need axes that are not all of the same size, i.e. that may span several rows and/or columns, you can create them with `plt.GridSpec`. This is a two-step process: you first create a grid by calling `grid = plt.GridSpec(nrows, ncols)` and then create the individual axes with `plt.subplot`.


In [None]:
# create grid
grid = plt.GridSpec(3, 3)

ax1 = plt.subplot(grid[0, :])
ax2 = plt.subplot(grid[1:3, :2])
ax3 = plt.subplot(grid[1, 2])
ax4 = plt.subplot(grid[2, 2]);
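
Indexing the grid works like slicing a NumPy array: `grid[1:3, :2]` selects rows 1 and 2 and columns 0 and 1. The sketch below reads the spans back from the axes via `get_subplotspec()`. It uses `f.add_subplot`, the figure method, instead of `plt.subplot`; this is one possible alternative, not what the cell above does.

```python
import matplotlib.pyplot as plt

f = plt.figure()
grid = plt.GridSpec(3, 3, figure=f)

ax_top = f.add_subplot(grid[0, :])
ax_big = f.add_subplot(grid[1:3, :2])

# rowspan and colspan list the grid cells covered by each axes
for ax in (ax_top, ax_big):
    spec = ax.get_subplotspec()
    print(list(spec.rowspan), list(spec.colspan))
```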

Here you can specify `hspace` and `wspace` directly: 

In [None]:
# create grid
grid = plt.GridSpec(3, 3, hspace=0.5, wspace=0.3)

ax1 = plt.subplot(grid[0, :])
ax2 = plt.subplot(grid[1:3, :2])
ax3 = plt.subplot(grid[1, 2])
ax4 = plt.subplot(grid[2, 2]);

_constrained layout_ also works:

In [None]:
f = plt.figure(layout="constrained")

# create grid
grid = plt.GridSpec(3, 3, figure=f)

ax1 = plt.subplot(grid[0, :])
ax2 = plt.subplot(grid[1:3, :2])
ax3 = plt.subplot(grid[1, 2])
ax4 = plt.subplot(grid[2, 2])

## Exercises



### Load Data

We will again use the station data (Temperature & Precip) for Switzerland, but this time we will use the time series instead of the climatology.

The data is available from MeteoSwiss.

The data has already been [retrieved and postprocessed](../data/prepare_data_MCH.ipynb).

In [None]:
import xarray as xr

In [None]:
def load_mch(station, annual=True):
    fN = f"../data/MCH_HOM_{station}.nc"
    ds = xr.open_dataset(fN, drop_variables=["station", "station_long"])

    if annual:
        # create annual data
        ds = ds.groupby("time.year").mean("time")

    return ds


BAS = load_mch("BAS")
BER = load_mch("BER")
GSB = load_mch("GSB")
DAV = load_mch("DAV")

LUG_monthly = load_mch("LUG", False)
ENG_monthly = load_mch("ENG", False)

### Exercise

 * Create a grid of Axes with 4 rows and 1 column
 * Populate the axes with the annual mean temperature of four stations
   ```python
   ax.plot(BAS.year, BAS.Temperature)
   ```
 * Loop through the axes to add `"T (°C)"` as the ylabel.
 * Add the station names as titles (set `loc="left"`)

In [None]:
station_names = ("BAS", "BER", "GSB", "DAV")
# ====

# code here
# f, axs =

# use this code to loop through the axes and station_names
# for i, ax in enumerate(axs):
#    station_names[i]

### Solution

In [None]:
station_names = ("BAS", "BER", "GSB", "DAV")

# ====

f, axs = plt.subplots(4, 1, sharex=True, layout="constrained")

ax = axs[0]
ax.plot(BAS.year, BAS.Temperature)

ax = axs[1]
ax.plot(BER.year, BER.Temperature)

ax = axs[2]
ax.plot(GSB.year, GSB.Temperature)

ax = axs[3]
ax.plot(DAV.year, DAV.Temperature)

for i, ax in enumerate(axs):
    ax.set_ylabel("T (°C)")

    ax.set_title(station_names[i], loc="left")
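
The four plotting calls can also be folded into the loop. The sketch below shows one way to do this with `zip`. To stay self-contained it uses synthetic random numbers in a dictionary called `temperature` (a hypothetical stand-in, not the MeteoSwiss series); with the real data you would loop over the four datasets instead.

```python
import numpy as np
import matplotlib.pyplot as plt

station_names = ("BAS", "BER", "GSB", "DAV")

# synthetic stand-in for the station data (NOT the MeteoSwiss series)
rng = np.random.default_rng(0)
years = np.arange(1900, 2021)
temperature = {name: 8 + rng.normal(size=years.size) for name in station_names}

f, axs = plt.subplots(4, 1, sharex=True, layout="constrained")

# plot, label and title each axes in one pass
for ax, name in zip(axs, station_names):
    ax.plot(years, temperature[name])
    ax.set_ylabel("T (°C)")
    ax.set_title(name, loc="left")
```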

### Exercise
Use the following code and
 * Manually add an axes in the top left corner
 * Add the data for the years 2000 to 2010 (`BAS_sel`)
 * Test what happens if you use `layout="constrained"`

In [None]:
f, ax = plt.subplots()

ax.plot(BAS.year, BAS.Temperature)

ax.set_ylim(None, 12.5)

BAS_sel = BAS.sel(year=slice(2000, 2010))

# code here


### Solution

In [None]:
f, ax = plt.subplots()

#f, ax = plt.subplots(layout="constrained")

ax.plot(BAS.year, BAS.Temperature)

ax.set_ylim(None, 12.5)

BAS_sel = BAS.sel(year=slice(2000, 2010))

# code here

ax2 = plt.axes([0.2, 0.65, 0.2, 0.20])

ax2.plot(BAS_sel.year, BAS_sel.Temperature)
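
An alternative that is not part of the exercise is `ax.inset_axes`, which places the new axes relative to its parent axes instead of the figure. Because the inset is positioned in axes coordinates, it moves together with the parent. The sketch below uses synthetic numbers (not the Basel data) to stay self-contained.

```python
import numpy as np
import matplotlib.pyplot as plt

# synthetic stand-in for the Basel series (NOT the MeteoSwiss data)
rng = np.random.default_rng(0)
years = np.arange(1900, 2021)
temperature = 9.5 + rng.normal(scale=0.5, size=years.size)

f, ax = plt.subplots(layout="constrained")
ax.plot(years, temperature)

# [left, bottom, width, height] in axes coordinates (0 to 1 within `ax`)
ax_inset = ax.inset_axes([0.05, 0.65, 0.3, 0.3])

# select the years 2000 to 2010 and plot them in the inset
sel = (years >= 2000) & (years <= 2010)
ax_inset.plot(years[sel], temperature[sel])
```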

### Exercise

 * Use `plt.GridSpec` to create two axes, one that takes ~ 3/4 of the width and the other ~ 1/4
 * Bonus
   * Plot precipitation data time series of BAS in the left axes
   * Plot a histogram of the same data in the right axes. Hint:
   ```python
   ax1.hist(BAS.Precipitation, bins=10, orientation='horizontal', density=True);
   ```

In [None]:
f = plt.figure(figsize=(20 / 2.54, 8 / 2.54))

# create grid
# grid =

### Solution

In [None]:
f = plt.figure(figsize=(20 / 2.54, 8 / 2.54))

# create grid
grid = plt.GridSpec(1, 4)

ax0 = plt.subplot(grid[0, :3])

ax1 = plt.subplot(grid[0, 3])

# plot data

ax0.plot(BAS.year, BAS.Precipitation)
ax1.hist(BAS.Precipitation, 10, orientation="horizontal", density=True)

# format plot
ax1.set_yticks([])

ax0.set_ylabel("Precipitation (mm)")

ax0.set_xlabel("Year")
ax1.set_xlabel("Probability density (1/mm)");
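
Another way to get the 3/4 to 1/4 split is the `width_ratios` keyword of `GridSpec`, which sets the relative width of each column. A sketch of this alternative (it only creates the axes and checks their widths, without plotting data):

```python
import matplotlib.pyplot as plt

f = plt.figure(figsize=(20 / 2.54, 8 / 2.54))

# two columns, the first three times as wide as the second
grid = plt.GridSpec(1, 2, width_ratios=[3, 1])

ax0 = plt.subplot(grid[0, 0])
ax1 = plt.subplot(grid[0, 1])

# compare the axes widths in figure coordinates
w0 = ax0.get_position().width
w1 = ax1.get_position().width
print(w0 / w1)
```

Compared with slicing a 1 x 4 grid, the gap between the two axes is here measured relative to the average of two columns instead of four, so the spacing may look different.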

### Bonus Exercise

 * Try to understand the following code:

In [None]:
# get data
x = LUG_monthly.Temperature
y = ENG_monthly.Temperature

f = plt.figure(figsize=(15 / 2.54, 15 / 2.54))

grid = plt.GridSpec(4, 4, hspace=0.5, wspace=0.5)

ax_main = plt.subplot(grid[1:, :-1])


ax_y = plt.subplot(grid[1:, -1], xticklabels=[], sharey=ax_main)
ax_x = plt.subplot(grid[0, :-1], yticklabels=[], sharex=ax_main)

ax_main.plot(x, y, ".")

ax_x.hist(x, 30, histtype="stepfilled", orientation="vertical", color="0.7")
ax_y.hist(y, 30, histtype="stepfilled", orientation="horizontal", color="0.7")

ax_main.set_xlabel("T Lugano (°C)")
ax_main.set_ylabel("T Engelberg (°C)");