# Matplotlib essentials for scientific visualization

By: Aitor Morales-Gregorio

In this PyMotw we will have a look at [matplotlib](https://matplotlib.org/) and how to overcome some of the major challenges of using it. Basic previous knowledge of the package is assumed, so we can move directly into the nitty-gritty annoying stuff.

This tutorial is heavily inspired by the book [Scientific Visualization](https://inria.hal.science/hal-03427242/document) by Nicolas P. Rougier.


### Anatomy of a figure (from the book)

In [None]:
%matplotlib inline

# ----------------------------------------------------------------------------
# Title:   Scientific Visualisation - Python & Matplotlib
# Author:  Nicolas P. Rougier
# License: BSD
# ----------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator, FuncFormatter

# Fix the RNG seed so the scattered samples are identical on every run.
np.random.seed(123)

# Two smooth signals on a shared x grid.
X = np.linspace(0.5, 3.5, 100)
Y1 = 3 + np.cos(X)
Y2 = 1 + np.cos(1 + X / 0.75) / 2
# Noisy samples lying between the two curves. Y1 (>= 2) is above Y2 (<= 1.5)
# everywhere, and np.random.uniform documents results for high < low as
# undefined, so draw U in [0, 1) explicitly and interpolate from Y1 towards
# Y2. This takes one uniform draw per point, like the previous
# np.random.uniform(Y1, Y2, len(X)) call.
Y3 = Y1 + (Y2 - Y1) * np.random.random_sample(len(X))

# One square 8 x 8 inch figure holding a single axes with equal x/y scaling.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, aspect="equal")


def minor_tick(x, pos):
    """Label a minor tick with two decimals, or leave it blank at integers.

    Used as a FuncFormatter callback: *x* is the tick value and *pos* its
    index (ignored). Integer values return an empty string; every other value
    is formatted as ``"%.2f"``.
    """
    fractional_part = x % 1.0
    return "%.2f" % x if fractional_part else ""


# Ticks: a major tick every 1.0 on both axes, each major interval split into
# 4 minor intervals.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(MultipleLocator(1.000))
    axis.set_minor_locator(AutoMinorLocator(4))
# Only the x axis labels its minor ticks (blank at integer positions).
ax.xaxis.set_minor_formatter(FuncFormatter(minor_tick))

ax.set(xlim=(0, 4), ylim=(0, 4))

# Long major ticks; short minor ticks with small gray labels.
ax.tick_params(which="major", width=1.0, length=10)
ax.tick_params(which="minor", width=1.0, length=5, labelsize=10, labelcolor="0.25")

# Thin dashed grid with a low zorder.
ax.grid(linestyle="--", linewidth=0.5, color=".25", zorder=-10)

# Two line plots and marker-only samples (linewidth 0 hides the connecting line).
ax.plot(X, Y1, color=(0.25, 0.25, 1.00), linewidth=2, label="Blue signal", zorder=10)
ax.plot(X, Y2, color=(1.00, 0.25, 0.25), linewidth=2, label="Red signal")
ax.plot(X, Y3, linewidth=0, marker="o", markerfacecolor="w", markeredgecolor="k")

ax.set_title("Anatomy of a figure", fontsize=20, verticalalignment="bottom")
ax.set(xlabel="X variable", ylabel="Y variable")

ax.legend(loc="upper right")


def circle(x, y, radius=0.15, axes=None):
    """Draw a hollow callout circle centred on (x, y).

    The circle has a thin black edge, an almost transparent fill and a thick
    white halo (``withStroke``) behind its outline. It is not clipped to the
    axes area, so it can also highlight elements outside the plotting region
    (tick labels, axis labels, the title).

    Parameters
    ----------
    x, y : float
        Centre of the circle in data coordinates of *axes*.
    radius : float, default 0.15
        Radius in data units.
    axes : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted, the module-level ``ax`` is used, so
        existing ``circle(x, y)`` calls behave exactly as before.
    """
    from matplotlib.patches import Circle
    from matplotlib.patheffects import withStroke

    if axes is None:
        axes = ax  # previous behaviour: draw on the notebook-level axes

    # Local name differs from the function name to avoid shadowing it.
    marker = Circle(
        (x, y),
        radius,
        clip_on=False,
        zorder=10,
        linewidth=1,
        edgecolor="black",
        facecolor=(0, 0, 0, 0.0125),
        path_effects=[withStroke(linewidth=5, foreground="w")],
    )
    axes.add_artist(marker)


def text(x, y, text, axes=None):
    """Write a callout label horizontally centred on *x*, hanging below *y*.

    Parameters
    ----------
    x, y : float
        Anchor point; the text is centred on *x* (``ha="center"``) and its
        top edge is placed at *y* (``va="top"``).
    text : str
        Label to write; may contain newlines.
    axes : matplotlib.axes.Axes, optional
        Axes to write on. When omitted, the module-level ``ax`` is used, so
        existing ``text(x, y, label)`` calls behave exactly as before.
    """
    if axes is None:
        axes = ax  # previous behaviour: write on the notebook-level axes

    axes.text(
        x,
        y,
        text,
        backgroundcolor="white",
        # fontname="Yanone Kaffeesatz", fontsize="large",
        ha="center",
        va="top",
        weight="regular",
        color="#000099",
    )


# Callout circles marking each element of the figure. Each entry is
# (element name, circle centre, label anchor). The name labels are switched
# off; uncomment the text(...) call in the loop to write them under the circles.
callouts = [
    ("Minor tick label", (0.50, -0.10), (0.50, -0.32)),
    ("Major tick", (-0.03, 4.00), (0.03, 3.80)),
    ("Minor tick", (0.00, 3.50), (0.00, 3.30)),
    ("Major tick label", (-0.15, 3.00), (-0.15, 2.80)),
    ("X axis label", (1.80, -0.27), (1.80, -0.45)),
    ("Y axis label", (-0.27, 1.80), (-0.27, 1.6)),
    ("Title", (1.60, 4.13), (1.60, 3.93)),
    ("Line\n(line plot)", (1.75, 2.80), (1.75, 2.60)),  # blue signal
    ("Line\n(line plot)", (1.20, 0.60), (1.20, 0.40)),  # red signal
    ("Markers\n(scatter plot)", (3.20, 1.75), (3.20, 1.55)),
    ("Grid", (3.00, 3.00), (3.00, 2.80)),
    ("Legend", (3.70, 3.80), (3.70, 3.60)),
    ("Axes", (0.5, 0.5), (0.5, 0.3)),
    ("Figure", (-0.3, 0.65), (-0.3, 0.45)),
]
for label, circle_xy, label_xy in callouts:
    circle(*circle_xy)
    # text(*label_xy, label)

# Arrow from inside the axes to the right-hand spine (x = 4.0 is the right x
# limit). The matching "Spines" text annotation is kept below, commented out.
color = "#000099"
ax.annotate(
    "",
    xy=(4.0, 0.35),
    xytext=(3.3, 0.5),
    color=color,
    weight="regular",  # fontsize="large", fontname="Yanone Kaffeesatz",
    arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color=color),
)

# ax.annotate(
#     "Spines",
#     xy=(3.15, 0.0),
#     xytext=(3.45, 0.45),
#     color=color,
#     weight="regular", # fontsize="large", fontname="Yanone Kaffeesatz",
#     arrowprops=dict(arrowstyle="->", connectionstyle="arc3", color=color),
# )

plt.show()

### Creating your default style

It is recommended to use *Style Sheets* to define your style in matplotlib, because this helps you keep your figures consistent across your papers.

There are some [default style sheets](https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html); please avoid these...

#### Custom style file

A style file is named something like `xxx.mplstyle` and is plain text like this:
```
# Font and labelsizes
xtick.labelsize : 7
ytick.labelsize : 7
axes.labelsize : 8
legend.fontsize : 7
axes.titlesize : 9

# Fonts
font.family : sans-serif
font.sans-serif : Arial

# Linestyle
axes.linewidth : 0.5

# Remove top and right spines
axes.spines.right: False
axes.spines.top: False

# Rendering params
savefig.facecolor: 'w'
axes.unicode_minus: False

# Tick sizes
xtick.major.size : 1
xtick.major.width : 0.5
ytick.major.size : 1
ytick.major.width : 0.5
xtick.major.pad : 1
ytick.major.pad : 1
xtick.minor.size : 0.7
xtick.minor.width : 0.4
ytick.minor.size : 0.7
ytick.minor.width : 0.4

```

The full list of all possible parameters, with extensive explanations, is provided in [the default matplotlibrc file](https://github.com/matplotlib/matplotlib/blob/v3.8.0/lib/matplotlib/mpl-data/matplotlibrc).

You can apply your style by running the following once (here the style file is saved as `paper.mplstyle`):

In [None]:
# Apply the custom style sheet. plt.style.use updates the global rcParams,
# so every figure created afterwards in this session uses these settings.
plt.style.use(style='paper.mplstyle')

In [None]:
# Fix the RNG seed so the scattered samples are identical on every run.
np.random.seed(123)

# Two smooth signals on a shared x grid.
X = np.linspace(0.5, 3.5, 100)
Y1 = 3 + np.cos(X)
Y2 = 1 + np.cos(1 + X / 0.75) / 2
# Noisy samples lying between the two curves. Y1 (>= 2) is above Y2 (<= 1.5)
# everywhere, and np.random.uniform documents results for high < low as
# undefined, so draw U in [0, 1) explicitly and interpolate from Y1 towards
# Y2. This takes one uniform draw per point, like the previous
# np.random.uniform(Y1, Y2, len(X)) call.
Y3 = Y1 + (Y2 - Y1) * np.random.random_sample(len(X))

# Redraw the anatomy figure at 4 x 3 inches. Fonts, tick sizes and spines now
# come from the active style sheet instead of per-call arguments.
fig = plt.figure(figsize=(4, 3))
ax = fig.add_subplot(111)

# A major tick at every integer on both axes.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(MultipleLocator(1.000))

ax.set(xlim=(0, 4), ylim=(0, 4.2))

# Faint dashed grid.
ax.grid(linestyle="--", linewidth=0.5, alpha=0.4, color=".25", zorder=-10)

ax.plot(X, Y1, color=(0.25, 0.25, 1.00), label="Blue signal", zorder=10)
ax.plot(X, Y2, color=(1.00, 0.25, 0.25), label="Red signal")
ax.plot(X, Y3, linewidth=0, marker="o", markersize=5,
        markerfacecolor="w", markeredgecolor="k")

ax.set_title("Anatomy of a figure", verticalalignment="bottom")
ax.set(xlabel="X variable", ylabel="Y variable")

# Legend without the surrounding frame.
ax.legend(loc="upper right", frameon=False)

### Placing axes where you want them to be

#### Mosaic

The [`plt.subplot_mosaic`](https://matplotlib.org/stable/users/explain/axes/mosaic.html#mosaic) is your friend.


In [None]:
# Panel A spans the whole top row; B and C sit in the bottom corners and '.'
# leaves the bottom-middle cell empty.
fig, axs = plt.subplot_mosaic('''
                              AAA
                              B.C
                              ''')

# Write each panel's mosaic key in its centre. Iterating over .items() with
# dedicated loop names keeps the notebook-wide ``ax`` (an Axes in the other
# cells) from being overwritten with a string key.
for key, panel in axs.items():
    panel.text(0.5, 0.5, key, fontsize=30, c='gray', ha='center', va='center')

#### Insets

The [`ax.inset_axes`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.inset_axes.html#matplotlib.axes.Axes.inset_axes) can be very useful to create zoom-ins, place colorbars precisely, add diagrams to a figure...

In [None]:
# Inset occupying the top-right region of panel A, in axes-relative
# coordinates [x0, y0, width, height].
axin = axs['A'].inset_axes([0.6, 0.6, 0.3, 0.3])
# With the inline backend the mosaic figure was already rendered (and closed)
# at the end of the previous cell, so the inset only becomes visible when the
# figure is displayed again: end the cell with the figure object.
axs['A'].figure

#### [EXTRA] Subfigures
They can be very useful in some cases, e.g. when part of the figure is a (large) regular grid of subplots and you don't want to type in all the panel names in a mosaic.

In [None]:



# A figure with constrained layout, split into two side-by-side subfigures.
fig = plt.figure(layout="constrained")
left, right = fig.subfigures(1, 2)

# Each subfigure can be filled like an independent figure. Uncomment to try:
# left.subplots(3,3)                       # a regular 3x3 grid on the left
# right.subplot_mosaic('''
#                      AB.
#                      CCC
#                      DDD
#                      ''')                # a custom mosaic on the right
plt.show()

### Color

DO NOT USE DEFAULT COLORS!

Matplotlib has a [long list of named colors](https://matplotlib.org/stable/gallery/color/named_colors.html) so you do not have to fight with RGB values.

TRANSPARENCY IS YOUR FRIEND (THE ALPHA PARAMETER)

In [None]:
import math

import matplotlib.pyplot as plt

import matplotlib.colors as mcolors
from matplotlib.patches import Rectangle


def plot_colortable(colors, *, ncols=4, sort_colors=True):
    """Draw a table of named color swatches and return the new figure.

    Each entry is drawn as a rectangle filled with its color (outlined in
    light gray) followed by its name. Entries fill the table column by
    column: top to bottom, then left to right.

    Parameters
    ----------
    colors : mapping
        Maps a display name to any Matplotlib color specification, e.g.
        ``mcolors.CSS4_COLORS``, ``mcolors.XKCD_COLORS`` or a custom palette
        such as ``{"signal": "#1f77b4"}``.
    ncols : int, default 4
        Number of columns in the table.
    sort_colors : bool, default True
        If truthy, order entries by the hue, saturation and value of their
        colors; otherwise keep the iteration order of *colors*.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, sized so each cell is ``cell_width`` x ``cell_height``
        pixels at 72 dpi.
    """
    # Layout sizes in pixels; they are converted to inches through dpi below.
    cell_width = 212
    cell_height = 22
    swatch_width = 48
    margin = 12

    if sort_colors:
        # Sort by the HSV of each entry's *value*, so any name -> color
        # mapping works, not only mappings whose keys are themselves valid
        # Matplotlib color names (for the built-in tables both give the
        # same order).
        names = sorted(
            colors,
            key=lambda c: tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(colors[c]))))
    else:
        names = list(colors)

    n = len(names)
    # At least one row, so an empty mapping does not produce identical
    # (singular) y-limits.
    nrows = max(1, math.ceil(n / ncols))

    width = cell_width * ncols + 2 * margin
    height = cell_height * nrows + 2 * margin
    dpi = 72

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin/width, margin/height,
                        (width-margin)/width, (height-margin)/height)
    ax.set_xlim(0, cell_width * ncols)
    # Inverted y-range so the first row is drawn at the top.
    ax.set_ylim(cell_height * (nrows-0.5), -cell_height/2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()

    for i, name in enumerate(names):
        row = i % nrows
        col = i // nrows
        y = row * cell_height

        swatch_start_x = cell_width * col
        text_pos_x = cell_width * col + swatch_width + 7

        ax.text(text_pos_x, y, name, fontsize=14,
                horizontalalignment='left',
                verticalalignment='center')

        ax.add_patch(
            Rectangle(xy=(swatch_start_x, y-9), width=swatch_width,
                      height=18, facecolor=colors[name], edgecolor='0.7')
        )

    return fig

# Render the CSS4 named colors as a swatch table.
plot_colortable(colors=mcolors.CSS4_COLORS)
plt.show()

Also, XKCD colors (yes, from the [online comic](https://xkcd.com/)) are supported, giving you an endless list to keep you entertained.

![](best_tasting_colors.png)

In [None]:
# Same swatch table for the XKCD color names (used as 'xkcd:<name>').
xkcd_fig = plot_colortable(colors=mcolors.XKCD_COLORS)
plt.show()

In [None]:
# Example usage
plt.plot(X, Y1, color='xkcd:vomit yellow', linewidth=2, linestyle='--')

#### Colormaps

The [colormaps shipped with matplotlib](https://matplotlib.org/stable/users/explain/colors/colormaps.html) include very good choices. The perceptually uniform ones (e.g. `viridis`, `cividis`) consider color blindness and perceived lightness, and vary continuously to ensure accurate data representation. The plots below show each colormap next to its lightness profile.

In [None]:
from colorspacious import cspace_converter

import matplotlib.pyplot as plt
import numpy as np

import matplotlib as mpl

# Larger default font for the colormap overview figures below.
mpl.rcParams['font.size'] = 14

# 100 evenly spaced sample points in [0, 1] for stepping through a colormap.
x = np.linspace(0.0, 1.0, num=100)

# A 2 x 256 horizontal ramp from 0 to 1 (two identical rows), to be shown as
# a color bar with imshow.
gradient = np.tile(np.linspace(0, 1, 256), (2, 1))


def plot_color_gradients(cmap_category, cmap_list):
    """Show each colormap next to a grayscale strip of its lightness.

    One row is drawn per name in *cmap_list*. The left panel shows the
    colormap itself. The right panel shows the lightness (first channel of the
    CAM02-UCS color space, computed with ``colorspacious``) on a black-to-white
    scale from 0 to 100. The colormap name is written to the left of each row,
    and the figure is shown with ``plt.show()``.

    Parameters
    ----------
    cmap_category : str
        Used in the figure title, ``"<cmap_category> colormaps"``.
    cmap_list : sequence of str
        Names registered in ``matplotlib.colormaps``; may contain a single
        entry.
    """
    # Sample points and color-bar ramp are built locally instead of being read
    # from the notebook globals ``x``/``gradient``: ``x`` is a generic name
    # that later cells rebind (e.g. to np.linspace(0, 50, 2000)).
    samples = np.linspace(0.0, 1.0, 100)
    ramp = np.linspace(0, 1, 256)
    ramp = np.vstack((ramp, ramp))

    # squeeze=False keeps ``axs`` 2-D even for a single colormap, so each row
    # can always be indexed as row[0] / row[1].
    fig, axs = plt.subplots(nrows=len(cmap_list), ncols=2, squeeze=False)
    fig.subplots_adjust(top=0.95, bottom=0.01, left=0.2, right=0.99,
                        wspace=0.05)
    fig.suptitle(cmap_category + ' colormaps', fontsize=14, y=1.0, x=0.6)

    for row, name in zip(axs, cmap_list):
        cmap = mpl.colormaps[name]

        # RGB values of the colormap at the sample points, shape (1, N, 3).
        rgb = cmap(samples)[np.newaxis, :, :3]

        # Get colormap in CAM02-UCS colorspace. We want the lightness
        # (channel 0), repeated over three rows to display it as a strip.
        lab = cspace_converter("sRGB1", "CAM02-UCS")(rgb)
        L = lab[0, :, 0]
        L = np.float32(np.vstack((L, L, L)))

        row[0].imshow(ramp, aspect='auto', cmap=cmap)
        row[1].imshow(L, aspect='auto', cmap='binary_r', vmin=0., vmax=100.)

        # Write the colormap name just left of its color strip (figure coords).
        pos = list(row[0].get_position().bounds)
        x_text = pos[0] - 0.01
        y_text = pos[1] + pos[3]/2.
        fig.text(x_text, y_text, name, va='center', ha='right', fontsize=10)

    # Turn off *all* ticks & spines, not just the ones with colormaps.
    for ax in axs.flat:
        ax.set_axis_off()

    plt.show()

# Colormap names grouped by category, in display order.
cmaps = {
    'Perceptually Uniform Sequential': [
        'viridis', 'plasma', 'inferno', 'magma', 'cividis'],
    'Sequential': [
        'Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
        'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
        'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn'],
    'Sequential (2)': [
        'binary', 'gist_yarg', 'gist_gray', 'gray', 'bone',
        'pink', 'spring', 'summer', 'autumn', 'winter', 'cool',
        'Wistia', 'hot', 'afmhot', 'gist_heat', 'copper'],
    'Diverging': [
        'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu',
        'RdYlGn', 'Spectral', 'coolwarm', 'bwr', 'seismic'],
    'Cyclic': ['twilight', 'twilight_shifted', 'hsv'],
    'Qualitative': [
        'Pastel1', 'Pastel2', 'Paired', 'Accent', 'Dark2',
        'Set1', 'Set2', 'Set3', 'tab10', 'tab20', 'tab20b',
        'tab20c'],
    'Miscellaneous': [
        'flag', 'prism', 'ocean', 'gist_earth', 'terrain',
        'gist_stern', 'gnuplot', 'gnuplot2', 'CMRmap',
        'cubehelix', 'brg', 'gist_rainbow', 'rainbow', 'jet',
        'turbo', 'nipy_spectral', 'gist_ncar'],
}

# One figure per category.
for cmap_category, cmap_list in cmaps.items():
    plot_color_gradients(cmap_category, cmap_list)

### Simple animations

It is very easy to create [animations with matplotlib](https://matplotlib.org/stable/users/explain/animations/animations.html#saving-animations). The syntax is a bit convoluted, but you just need to remember that everything in matplotlib is an object: if you can modify an object's properties (angle, color, position...), you can animate it by writing a suitable `update()` function that is called for every frame.


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation

In [None]:
%matplotlib tk

fig, ax = plt.subplots()
x = np.linspace(0, 50, 2000)
y = np.sin(x)

# ax.plot returns a list of Line2D objects; unpack the single line so the
# update function can modify it directly.
line, = ax.plot(x, y)


def update(frame):
    """Shift the sine wave by *frame* radians and return the modified artists."""
    line.set_ydata(np.sin(x + frame))
    # FuncAnimation expects an iterable of artists (required with blit=True).
    # Note the trailing comma: ``(line)`` would be just ``line``, not a tuple.
    return (line,)


ani = animation.FuncAnimation(fig=fig, func=update, frames=10, interval=30)
plt.show()

In [None]:
# Figure with a single 3D axes.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Trefoil knot as a parametric curve (y1, y2, y3) of x, sampled in steps of
# 0.01 over [0, 2*pi).
x = np.arange(0, 2 * np.pi, step=0.01)
y1 = np.cos(x) + 2 * np.cos(2 * x)
y2 = np.sin(x) - 2 * np.sin(2 * x)
y3 = 2 * np.sin(3 * x)

ax.plot(y1, y2, y3)
ax.view_init(azim=0)  # start the rotation at azimuth 0
ax.set(title='Trefoil knot', xlabel='y1', ylabel='y2', zlabel='y3')


def update(frame):
    """Rotate the camera: the frame number is used as the azimuth angle."""
    ax.view_init(azim=frame)


ani = animation.FuncAnimation(fig, update, frames=500, interval=10)
plt.show()

In [None]:
# Save the most recent animation (`ani`) as an MP4 through ffmpeg.
# GIF is avoided on purpose: it is a very inefficient format.
video_path = 'animation.mp4'
FFwriter = animation.FFMpegWriter(fps=20)  # frames per second sets the playback speed
ani.save(filename=video_path, writer=FFwriter, dpi=400)