---
title: Utils for Matplotlib
---

In [None]:
# | default_exp matplotlib
# | export
from matplotlib.pyplot import Axes
from matplotlib.figure import Figure

In [None]:
# | export
def _iter_axes(axes):
    """
    Yield the individual items of a single Axes or an arbitrarily nested container of them.

    Nested containers (e.g. the 2-D array returned by ``plt.subplots(2, 2)``) are flattened;
    Axes and any non-iterable item are yielded unchanged.
    """
    if isinstance(axes, Axes):
        yield axes
        return
    for item in axes:
        # Treat Axes as leaves; also treat strings/bytes as leaves, since iterating a
        # one-character string yields itself and would otherwise recurse forever.
        if isinstance(item, (Axes, str, bytes)) or not hasattr(item, "__iter__"):
            yield item
        else:
            yield from _iter_axes(item)


def func2axes(axes: Axes | list[Axes], func, *args, **kwargs):
    """
    Call ``func(ax, *args, **kwargs)`` once for every axis in ``axes``.

    ``axes`` may be a single Axes, a list of Axes, or a nested container of Axes such as
    the 2-D array returned by ``plt.subplots(nrows, ncols)``; nested containers are
    flattened before ``func`` is applied. Returns None.
    """
    for ax in _iter_axes(axes):
        func(ax, *args, **kwargs)

In [None]:
# | export
# | hide
def unify_axis_fontsize(ax: Axes, fontsize: str | float):
    """
    Apply one font size to the title, x/y axis labels and the current major tick labels of ``ax``.
    """
    texts = [ax.title, ax.xaxis.label, ax.yaxis.label]
    texts.extend(ax.get_xticklabels())
    texts.extend(ax.get_yticklabels())
    for label in texts:
        label.set_fontsize(fontsize)

In [None]:
# | export
def unify_axes_fontsize(
    axes: Axes | list[Axes],  # a single axis or a list of axes
    fontsize: str | float = "medium",  # a size in points, or a string such as "small"/"large" relative to the default font size
):
    """
    Set the title, axis labels and major tick labels of one or more matplotlib axes to the same font size.

    Other text on the axes (legend entries, annotations, minor tick labels) is left unchanged.
    """
    func2axes(axes, unify_axis_fontsize, fontsize)

In [None]:
# | export
def hide_axis_legend(ax: Axes):
    """
    Hide the legend of ``ax`` if it has one; axes without a legend are left untouched.

    ``ax.get_legend()`` is used instead of ``ax.legend()`` because the latter builds a *new*
    legend: it replaces any customised legend and warns when there are no labelled artists.
    """
    legend = ax.get_legend()
    if legend is not None:
        legend.set_visible(False)


def hide_axes_legend(axes: Axes | list[Axes]):
    """Hide the legend on a single axis, or on each axis of a list."""
    func2axes(axes, func=hide_axis_legend)


def hide_fig_legend(fig: Figure):
    """
    Hide the legends of ``fig``.

    This covers the legend of every axis in ``fig.axes`` and any figure-level legends
    created with ``fig.legend()``. The latter live in ``fig.legends`` rather than on
    an axis, so they need their own pass.
    """
    hide_axes_legend(fig.axes)
    for legend in fig.legends:
        legend.set_visible(False)

In [None]:
# | export
def hide_x_axis_label(ax: Axes):
    """Blank out the x-axis label of ``ax`` by setting its text to an empty string."""
    ax.set_xlabel(xlabel="")


def hide_y_axis_label(ax: Axes):
    """Blank out the y-axis label of ``ax`` by setting its text to an empty string."""
    # Use "" (like hide_x_axis_label) rather than None: the effect is the same because
    # Text.set_text maps None to "", but "" matches the documented ``str`` parameter.
    ax.set_ylabel("")


def hide_x_axes_label(axes: Axes | list[Axes]):
    """Blank out the x-axis label on a single axis, or on each axis of a list."""
    func2axes(axes, func=hide_x_axis_label)


def hide_y_axes_label(axes: Axes | list[Axes]):
    """Blank out the y-axis label on a single axis, or on each axis of a list."""
    func2axes(axes, func=hide_y_axis_label)