In [None]:
%matplotlib inline
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt

# SA2 boundary layer: mapped directly in the classification-scheme comparisons
# below, and drawn as an outline over the SA1 age maps.
sa2 = gpd.read_file(filename = "data/sa2-wellington.gpkg")

# Small multiple maps
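There are many ways to lay out small-multiple maps. The example below uses an approach I've settled on over time: create a grid of axes with `plt.subplots`, then loop over the things you want to compare and draw one map into each axes.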

In [None]:
nrows, ncols = 2, 2
figure, axes = plt.subplots(nrows = nrows, ncols = ncols,
                            figsize = (12, 12),
                            layout = "constrained")
schemes = ["equalinterval", "quantiles", "prettybreaks", "fisherjenks"]

# Draw one choropleth per classification scheme, filling the grid row by row
# (divmod turns the running index into a (row, column) position).
for i, scheme in enumerate(schemes):
    row, col = divmod(i, ncols)
    ax = axes[row, col]
    sa2.plot(
        ax = ax,
        column = "pop_density",
        cmap = "Reds",
        scheme = scheme,
        k = 7,
        ec = "k",
        lw = 0.5,
        legend = True,
        legend_kwds = dict(title = "Pop density per sq km",
                           loc = "upper left"),
    )
    ax.set_axis_off()
    ax.set_title(f"{scheme = }")
plt.show()

Here is a different example: the percentage of people in each age group, mapped by SA1.

In [None]:
sa1 = gpd.read_file("data/sa1-wellington.gpkg")
# Read the join key as text up front so it can match SA12023_V1_00 (which the
# join below presumably stores as text). Converting a numeric column afterwards
# would produce codes like "7001234.0" if any value were missing, and would
# drop leading zeros.
ages = pd.read_csv("data/welly-ages-final.csv", dtype = {"sa1_code": str})
sa1_ages = sa1.merge(
    ages, left_on = "SA12023_V1_00", right_on = "sa1_code")
# Every column whose name contains "age" is treated as an age-group count.
# NOTE(review): this is a substring match; confirm that no other column
# (e.g. one coming from the SA1 layer) contains "age".
age_vars = [n for n in sa1_ages.columns if "age" in n]
totals = sa1_ages[age_vars].sum(axis = "columns")
# Convert counts to percentages of each SA1's total across the age groups.
# Scale to percent *before* rounding: rounding the proportion and then
# multiplying by 100 leaves float artefacts such as 28.999999999999996.
sa1_ages[age_vars] = (sa1_ages[age_vars]
                      .div(totals, axis = "index")
                      .mul(100)
                      .round(1))
sa1_ages.head()


In [None]:
ncols = 4
# One panel per age group, adding rows as needed. squeeze = False keeps `axes`
# two-dimensional even if this works out to a single row, so the [row, col]
# indexing below always works.
nrows = (len(age_vars) + ncols - 1) // ncols
figure, axes = plt.subplots(figsize = (12, 3 * nrows),
                            nrows = nrows, ncols = ncols,
                            squeeze = False,
                            layout = "constrained")
# Shared colour scale so the panels are directly comparable; values above
# vmax are drawn in the darkest colour.
vmin, vmax, cmap = 0, 25, "Reds"
for i, var in enumerate(age_vars):
    ax = axes[i // ncols, i % ncols]
    sa1_ages.plot(ax = ax, column = var, legend = False,
                  vmin = vmin, vmax = vmax, cmap = cmap)
    # Overlay SA2 boundaries with a fully transparent fill for context.
    sa2.plot(ax = ax, fc = "#00000000", ec = "k", lw = 0.3)
    ax.set_axis_off()
    ax.set_title(var)
# Hide every leftover panel in the last row, however many age groups there are.
for ax in axes.flat[len(age_vars):]:
    ax.set_axis_off()
# A single colourbar for the whole grid, since every panel uses the same scale.
figure.colorbar(plt.cm.ScalarMappable(norm = plt.Normalize(vmin, vmax), cmap = cmap),
                ax = axes, shrink = 0.5, extend = "max",
                label = "% of SA1 total")
plt.show()

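One wrinkle with this method is that if your figure has only one row or one column, `plt.subplots` returns a one-dimensional array of axes, so you index it with a single value. (Passing `squeeze = False` to `plt.subplots` always gives a two-dimensional array instead.)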

In [None]:
nrows, ncols = 1, 4
figure, axes = plt.subplots(figsize = (12, 3), 
                            nrows = nrows, ncols = ncols, 
                            layout = "constrained")
schemes = ["equalinterval", "quantiles", "prettybreaks", "fisherjenks"]

# With a single row, plt.subplots returns a 1-D array of axes (its default
# squeeze = True drops the length-1 dimension), so one index is enough here.
for i, scheme in enumerate(schemes):
    ax = axes[i]
    sa2.plot(ax = ax, column = "pop_density", cmap = "Reds", 
             scheme = scheme, k = 7,
             ec = "k", lw = 0.5, legend = False)
    ax.set_axis_off()
    # Without legends, the title is the only way to tell the panels apart.
    ax.set_title(f"{scheme = }")
plt.show()