# Getting started with Matplotlib and `segretini-matplottini`, with a barplot

If you are here, chances are that you are tired of dull spreadsheet-like plots, and you want to bring your visualization skills to the next level to obtain a fabled `the paper's plots are so beautiful` comment from Reviewer #2.

🚂 Let's get started! In this notebook, we will see an introduction to [Matplotlib](https://matplotlib.org) and [Seaborn](https://seaborn.pydata.org), but most importantly to some of the key concepts to keep in mind when creating a great visualization. We will also see how `segretini-matplottini` can make our life easier, by taking care of some of the tedious work needed to get your plots into shape.

In the end, we will no longer create plots like the first one, and we will be able to create plots like the second one.

<div>
<img src="../plots/notebooks/1_getting_started_with_barplots/the_ugliest_barplot.png" height="300"/>
<img src="../plots/notebooks/1_getting_started_with_barplots/a_much_better_barplot.png" height="300"/>
</div>

## Exploratory visualization vs message-delivery plotting

Not all visualizations are created with the same goal. As such, not all plots have the same needs.
A major distinction happens between the following two categories: **exploratory visualization** and **message-delivery plotting**, a name I just made up and that will be clear in a second.

### 🔎 Exploratory visualization
The idea of exploratory visualization is that you have some data that you know very little about, and you want to learn more about it. The main point is that **you don't know yet what you are looking for**. You might have a hint, but the unknowns outnumber the knowns.

The most common situation is *exploratory data analysis*: you have a dataset, and want to understand something about it, e.g. class distributions, with the goal of building a predictive model that can leverage this newly discovered information.
Another common situation is when you run a very complex experiment that tracks many metrics over multiple benchmark datasets. You don't know in advance what to expect, and you can use an exploratory visualization to bring order to the chaos.

Exploratory visualization lies in the world of fast prototyping. The main audience of an exploratory visualization is **yourself** (or a narrow pool of people who are knowledgeable about the topic). **The goal is to use visualizations to understand something new.**

### ⭐️ Message-delivery plotting

In message-delivery plotting, you want to convey a message to an external target audience, and you need to convince this audience that your message is true.
For example, the message could be that your algorithm is more accurate than all the existing implementations of that same algorithm.

How to best deliver the message is up to you: a sentence, a table, or a visualization. In many cases, a well-made visualization is the most impactful approach.

In this situation, you know exactly what you want to show, and your goal is to show it in the **clearest way possible**. It has to be polished, curated, and clutter-free. It takes a long time, trial and error, and patience.

Message-delivery plotting is the focus of this notebook, and of `segretini-matplottini`. We don't focus on exploratory visualization. If you want to know more about exploratory visualization, there are plenty of valid options, such as Seaborn's [`FacetGrid`](https://seaborn.pydata.org/tutorial/axis_grids.html) and [Plotly](https://plotly.com/python/).

## Getting started

To run the following code, install `segretini-matplottini` as explained in the [README](../README.md). As a reminder, the following should suffice.

```shell
git clone https://github.com/AlbertoParravicini/segretini-matplottini.git
cd segretini-matplottini
pip install ".[notebook]"
jupyter notebook notebooks/1_getting_started_with_barplots.ipynb
```

# Building a plot with Matplotlib

Matplotlib's plots are created by building a `Figure` and one or more `Axes` within the `Figure`. Each `Axes` is a single plot (for example, a barplot), and a `Figure` can contain multiple `Axes` (for example, a grid of barplots). The following code shows how to build the simplest plot, which, for now, is empty.

In [None]:
# `pyplot` is the main plotting interface of Matplotlib;
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes

# Reset Matplotlib to its default style settings;
plt.rcdefaults()

# Build a Figure containing a single Axes;
fig: Figure
ax: Axes
fig, ax = plt.subplots(nrows=1, ncols=1)

# Show the empty plot;
plt.show()

Let's unpack what we did above 🙌

First, we called `plt.rcdefaults()`. That's a bit out of the blue, but it's a very important command to remember. Matplotlib keeps a global state of style-related settings (we'll see later how to modify them), and this command ensures that everything is brought back to its default value before creating a plot. Without `plt.rcdefaults()`, it is easy to override some settings when creating multiple plots in the same script. Calling `plt.rcdefaults()` ensures that we always start from a blank slate.
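To see the global state in action, here is a minimal sketch (not part of the original notebook) that changes one setting and then restores it with `plt.rcdefaults()`. It selects the non-interactive `Agg` backend, an assumption made only so that it runs without a display.

```python
import matplotlib as mpl

mpl.use("Agg")  # Non-interactive backend, so that the sketch also runs without a display;
import matplotlib.pyplot as plt

# Change a global setting: every plot created from now on inherits it;
plt.rcParams["font.size"] = 20
print(plt.rcParams["font.size"])  # 20.0

# Bring every style setting back to Matplotlib's built-in default;
plt.rcdefaults()
print(plt.rcParams["font.size"] == mpl.rcParamsDefault["font.size"])  # True
```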

Then, we created a `Figure` with `plt.subplots()`, and the figure contains a single `Axes` (that is, an `x-axis` and a `y-axis`).
* We can pass additional arguments to `plt.subplots` to control the number of `Axes`, their layout, and the resolution of the plot. We'll see some of these options later.
    
Finally, we called `plt.show()` to display the plot. 
* The command can be omitted when working in a Jupyter notebook, as the plot will be displayed automatically. However, it is necessary when working in a Python script. 
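* When writing a script, it is good practice to act on the `Figure` explicitly. Note, however, that `fig.show()` does not run a GUI event loop, so in a plain script the window may close immediately: there, `plt.show()` is still needed to keep the window open, while `fig.savefig(...)` explicitly exports a specific `Figure` to a file.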

More generally, Matplotlib provides some shorthands to access the current `Figure` and `Axes`.
  * To access the current `Figure`, we can use `plt.gcf()`, and to access the current `Axes`, we can use `plt.gca()`.
  * `plt.show()` displays all open figures, while `fig.show()` targets a single `Figure`.
  * Similarly, one can call `plt.savefig()` instead of `fig.savefig()`, or `plt.plot()` instead of `ax.plot()`. These act on the current `Figure` and `Axes`, i.e. they are equivalent to `plt.gcf().savefig()` and `plt.gca().plot()`.
  * My advice is to always use the `Figure` and `Axes` objects explicitly, as it reduces the chance of errors when working with multiple `Figures` and `Axes`.
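The difference between the explicit and the implicit style can be sketched as follows. The `Agg` backend is an assumption so that the example runs without a display, the plotted values are arbitrary, and saving to an in-memory buffer stands in for saving to a file.

```python
import io

import matplotlib as mpl

mpl.use("Agg")  # Non-interactive backend, so that the sketch also runs without a display;
import matplotlib.pyplot as plt

# Explicit interface: we keep references to the Figure and the Axes;
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.plot([0, 1, 2], [0, 1, 4])

# Implicit interface: pyplot tracks the "current" Figure and Axes on our behalf;
print(plt.gcf() is fig)  # True
print(plt.gca() is ax)  # True

# Creating a second Figure silently changes what `plt.gcf()` and `plt.gca()` refer to;
other_fig, other_ax = plt.subplots(nrows=1, ncols=1)
print(plt.gcf() is fig)  # False
print(plt.gca() is other_ax)  # True

# Acting on the explicit object removes the ambiguity: this always saves the first Figure;
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
```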

# Let's load some data

We will use a simple dataset in this notebook. It contains the estimated quality of different predictive models over a few different datasets.

Remember that we are not doing an exploratory visualization, but we want to visualize some known results to convey a message. We know in advance the structure of the data we are loading, and we most likely know what to expect from this data. 
For example, we know that a certain model gave better predictions than another, and we want to show it in our plot.

We use `pandas` to handle datasets since we are going to create plots that benefit from a tabular representation.
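The CSV itself is not reproduced here, but the code in this notebook reads its `model`, `dataset`, and `value` columns, which points to a "long" layout with one row per (model, dataset) pair. A sketch of that layout with made-up values, next to the equivalent "wide" table:

```python
import pandas as pd

# Hypothetical values, only to show the layout: one row per (model, dataset) pair;
toy = pd.DataFrame(
    {
        "model": ["A", "B", "C", "A", "B", "C"],
        "dataset": ["dataset_1"] * 3 + ["dataset_2"] * 3,
        "value": [0.40, 0.60, 0.70, 0.50, 0.80, 0.75],
    }
)
print(toy)

# The same data in "wide" format: one row per dataset, one column per model.
# Wide tables are easier to read, but the long format maps directly to Seaborn's
# `x`, `y`, and `hue` arguments;
wide = toy.pivot(index="dataset", columns="model", values="value")
print(wide)
```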

In [None]:
import pandas as pd

# Load the data from the CSV;
data = pd.read_csv("../data/notebooks/1_getting_started_with_barplots/barplot_data.csv")

# This prints the dataset in a nice format, in the notebook;
data

If we look at the data, we see that there are three models and five datasets. In our experiments, model `A` is a baseline, which performs quite poorly. Model `B` is the current state-of-the-art, and model `C` is our new amazing model.
Overall, model `C` is the best, when averaged across all datasets, and we want to show that!

In [None]:
# Average the results of each model across datasets, and print the best model first;
data.groupby("model").mean(numeric_only=True).sort_values("value", ascending=False)

# Our first barplot

We use Seaborn since it provides simple ways to plot data stored in a `DataFrame`, and to group the plot elements (e.g. the bars) according to our needs.

We can group the models by experiment so that we can easily compare the models within each experiment.

Seaborn operates on top of Matplotlib, so we can create the plot within an existing `Axes`, and the output will also be an `Axes` that we can manipulate using any of the functionalities offered by Matplotlib. 

You could achieve the same result without using Seaborn, but it would be significantly more complex. Indeed, most of the plots in `segretini-matplottini` are created using Seaborn under the hood. Not all of them though! Check out [`timeseries`](../examples/plot_timeseries.py) or [`binary_classification`](../examples/plot_binary_classification.py) for examples of complex plots that are created with vanilla Matplotlib.

In [None]:
import seaborn as sns

# Reset Matplotlib to its default style settings;
plt.rcdefaults()

# Build a Figure containing a single Axes;
fig: Figure
ax: Axes
fig, ax = plt.subplots(nrows=1, ncols=1)

# Add a barplot with Seaborn on the Axes we just created;
sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Show the barplot;
plt.show()

## Things we do not like

Our barplot, out of the box, does not look that bad. It is arguably better than the dreaded spreadsheet plot that we started with. Less wasted white space, and more readable labels. But there are still many things that we can improve.

Let's list some areas of improvement that a trained plot-master would spot in a second.
* Is our model really the best? Sometimes it is better, sometimes it is worse. **The message that our model is the best is not clear, and that's the highest priority fix to make.**
* The legend is not informative. What's our model? What's `C`? Our model is `C`, but that's not written anywhere. 
* The magnitude of the differences is unclear. The y-axis has labeled ticks, but at a glance, it's not possible to say that model `B` is 10% or 20% better or worse than something else.
* Does a higher bar even correspond to better quality?
* Colors are not meaningful.
* The order of the bar groups is wrong. Why would `dataset_10` appear before `dataset_1`?
* The size of the plot is likely to be wrong. If the goal is to put the plot in a publication or in some other place where space is limited, we need to guarantee that the plot uses the available space efficiently. 
* Overall, the style of the plot is very *vanilla*. Nothing wrong with that, but a bit of visual flair can make our visualization stand out from the crowd.
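About the wrong order of the bar groups: a plain string sort compares `dataset_10` and `dataset_2` character by character, so the `1` wins over the `2`. One common fix, sketched below with hypothetical names and a helper `natural_key` that is not part of `segretini-matplottini`, is a "natural" sort that compares the numeric parts as integers.

```python
import re

# Hypothetical dataset names; a plain string sort puts "dataset_10" before "dataset_2";
names = ["dataset_1", "dataset_10", "dataset_2", "dataset_3"]
print(sorted(names))  # ['dataset_1', 'dataset_10', 'dataset_2', 'dataset_3']


def natural_key(name: str) -> list:
    # Split digit runs from text, so that numbers are compared as integers;
    return [int(token) if token.isdigit() else token for token in re.split(r"(\d+)", name)]


print(sorted(names, key=natural_key))  # ['dataset_1', 'dataset_2', 'dataset_3', 'dataset_10']
```

The resulting list could then be passed to Seaborn's `order` parameter, e.g. `sns.barplot(..., order=sorted(data["dataset"].unique(), key=natural_key))`.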

Before diving deeper into these issues, it is a good moment to mention two important rules of thumb.

### Just one message

If you have ten datasets, it might be tempting to show the performance of the models across all datasets. The result will be very confusing and hard to read. Instead, showing the average performance across all datasets will deliver the message more intuitively. Readers will spot the best model at a glance, instead of puzzling over the graph for minutes.

When creating a visualization, ask yourself **"What is the message that I want to convey?"**. Everything has to be aligned with the message. Everything else has to go.

### More is better, but only if it helps your message

This second point is closely connected to the previous one. Adding more information to your plot, such as an arrow that highlights the best-performing model, a label that mentions how a higher score is better, or a grid that helps compare distant bars, can make the delivery of your message more powerful.

But adding elements for the sake of it, such as error bars in experiments where variance is minimal and all differences are statistically significant, or results split across a dozen different datasets, will counterintuitively weaken your message by making the plot harder to read.

Truth be told, many amazing papers have plots just like the one we just drew, and this did not weaken the quality of their research at all. But there are also many cases where a better plot can make your achievements clearer, and tip the balance in your favor.

# Making a better barplot

In the following sections, we will dissect the points in **"Things we do not like"**, and create a better barplot, one improvement at a time. Let's begin! 🤘

## Making our message stronger

The message that we want to deliver, in our plot, is that we have a new model whose quality is significantly better than the current state-of-the-art. The previous plot did not convey this message well. Sometimes our model is the best, but sometimes it is slightly worse. Letting the readers compute averages in their heads, or draw conclusions by themselves, weakens the delivery of your message. 

### Showing the average performance of each model

The most impactful improvement we can make to our plot is to show the average performance of each model, across all datasets. Just like in our previous `DataFrame`, it will be obvious that model `C` (our model) is the best of the bunch.
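On a small, hypothetical `DataFrame`, the steps of the next cell look like this. The sketch uses `as_index=False` to keep `model` as a column, and `ignore_index=True` so that the concatenated `DataFrame` gets a clean, non-duplicated index; both are small variations on the notebook's code, not requirements.

```python
import pandas as pd

# Hypothetical results in the same long format used in this notebook;
toy = pd.DataFrame(
    {
        "model": ["A", "B", "C", "A", "B", "C"],
        "dataset": ["dataset_1"] * 3 + ["dataset_2"] * 3,
        "value": [0.40, 0.60, 0.70, 0.50, 0.80, 0.90],
    }
)

# One row per model, with its mean value across datasets;
averages = toy.groupby("model", as_index=False)["value"].mean()
# The averages behave like an extra "dataset" named `Average`;
averages["dataset"] = "Average"

# Put the averages first, so that they are plotted first, and renumber the index;
with_average = pd.concat([averages, toy], ignore_index=True)
print(with_average.head(3))
```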

In [None]:
# Load the data from the CSV;
data = pd.read_csv("../data/notebooks/1_getting_started_with_barplots/barplot_data.csv")

# Average the results of each model across datasets;
average_results: pd.DataFrame = data.groupby("model").mean(numeric_only=True)

# We want to create a new DataFrame that contains the average results.
# To do so, we need to add the `model` and `dataset` columns back.
# The model column is already available, and it is the index of `average_results`.
# Let's make it a column again;
average_results = average_results.reset_index()
# The `dataset` column will have the value `Average` for all models, instead of being a specific dataset;
average_results["dataset"] = "Average"

# Create a new DataFrame that contains the average results.
# The DataFrame with the average results is added first, so that it is plotted first;
data = pd.concat([average_results, data], ignore_index=True)

# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)
sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Show the barplot;
plt.show()

A simple change can go a long way! Now it's obvious that model `C` has the highest bars on average, and model `A` has the lowest bars on average. 

But we are not done yet. Which one is our model? Does a higher bar denote a better model? How much are the differences between the models? Let's fix these problems one at a time.

### A better legend

Based on the nomenclature we chose in our experiments, we know that model `A` is a baseline, model `B` is the state-of-the-art, and model `C` is our model. But a reader will not know that! Let's use better names for our models.

There are multiple ways to do this. We could rename the values directly in the `DataFrame`, as follows.

```python
data["model"] = data["model"].replace({"A": "Baseline", "B": "State-of-the-art", "C": "Our model"})
```

This works, but we have other options that do not require modifications to our dataset. The advantage is that we have more flexibility. For example, we can refer to the same model in multiple ways, in different parts of the code. Also, it gives the chance to show how Matplotlib works.

A Matplotlib legend contains `handles` and `labels`, and each `handle` corresponds to a `label`. A `handle` is the colored rectangle that identifies an element of the plot (a bar type, in our case). The `label` is the text that goes with the `handle`.

To update the entries of the legend, we first obtain the current `handles` and `labels`, then replace the `labels` with new ones.
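The next cell passes the new labels as a list, which assumes that the legend entries appear in the order `A`, `B`, `C`. A slightly more robust variant, sketched below, looks up each label by name in a dictionary (`display_names` is a hypothetical name). To keep the sketch self-contained, plain `ax.bar` calls stand in for the Seaborn barplot, and the `Agg` backend is assumed.

```python
import matplotlib as mpl

mpl.use("Agg")  # Non-interactive backend, so that the sketch also runs without a display;
import matplotlib.pyplot as plt

# Hypothetical mapping from the names used in the data to the names shown to readers;
display_names = {"A": "Baseline", "B": "State-of-the-art", "C": "Our model"}

fig, ax = plt.subplots(nrows=1, ncols=1)
# Stand-in for the Seaborn barplot: one bar per model, labeled with its raw name;
for position, (model, value) in enumerate([("A", 0.5), ("B", 0.7), ("C", 0.8)]):
    ax.bar(position, value, label=model)

handles, labels = ax.get_legend_handles_labels()
# Translate each label by name, instead of relying on the order of the legend entries;
legend = ax.legend(handles=handles, labels=[display_names.get(label, label) for label in labels])
print([text.get_text() for text in legend.get_texts()])  # ['Baseline', 'State-of-the-art', 'Our model']
```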

In [None]:
# Plot the results, like we did before;

plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)
ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Update the existing legend labels to have more meaningful names;
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

# Show the barplot;
plt.show()

### Is higher better?

From our plot, it is not immediately clear whether a higher bar corresponds to a better or worse outcome. Intuitively, higher is better here (we would not show this plot if the baseline were better than the other models), but there are many cases where the situation is not so clear-cut. For example, you might have multiple plots with different metrics, some of which should be as high as possible and some as low as possible. In that case, helping the reader understand the direction of improvement of each metric can make the plot significantly easier to read.

How do we improve the situation? Sometimes, a meaningful y-axis label is enough: a model's accuracy should always be as high as possible, and latency should always be as low as possible. But what if you have a very domain-specific metric such as [NDCG@K](https://en.wikipedia.org/wiki/Discounted_cumulative_gain) or [LPIPS](https://arxiv.org/pdf/1801.03924.pdf)? Not everyone will be immediately familiar with its meaning. Also, what if you have a metric whose direction of improvement is context-dependent? For example, *CPU utilization*. If you are building a distributed job scheduler, you most likely want to maximize the CPU utilization of as few machines as possible. On the other hand, if you are building a low-profile task monitor, you want its CPU utilization to be as low as possible.

So, let's think of a better alternative. The simplest option is to just specify the direction of improvement in the y-label.

In [None]:
# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)
ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Update the existing legend labels to have more meaningful names;
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

# Update the y-axis label;
ax.set_ylabel("Value (higher is better)")
# Let's also update the x-axis label;
ax.set_xlabel("Dataset")

# Show the barplot;
plt.show()

Often, this is enough. But the rotated text might be hard to read, or you might not have enough vertical space for a long label. Let's see another, more complex, approach: adding a vertical arrow to indicate the direction of improvement. We will do this through a separate function so that it can be reused later without repeating a lot of code. 

By the way, a similar function is available in `segretini-matplottini`: check out [`add_arrow_to_barplot`][add_arrow_to_barplot].

[add_arrow_to_barplot]: ../segretini_matplottini/utils/plot.py#L125

### 💡 Quick introduction to coordinates

When adding a new element to a plot, such as a textual label, you need to specify its coordinates. Coordinates in Matplotlib are a complex topic ([link](https://matplotlib.org/stable/tutorials/advanced/transforms_tutorial.html) for more information), but we can simplify it as follows.

* `data` coordinates are the default (transform `ax.transData`), and they are what you obtain when calling `ax.get_xlim()` or `ax.get_ylim()`. In our plot, `get_ylim` gives `[0, 1.54]`. The values of `get_xlim` are less straightforward since we are plotting bars and not a set of points: each bar group is centered on an integer position, so the limits roughly span the number of bar groups, plus a bit of padding. Here, `get_xlim` gives something around `[-0.5, 5.5]`, since we have six bar groups. Using `data` coordinates is useful when adding elements relative to some existing plot item, e.g. labels on top of the bars.
* `axes` coordinates (transform `ax.transAxes`, or `xycoords="axes fraction"` in `ax.annotate`) are given as a fraction of the axes' size. The value `[0, 0]` is the bottom left, and the value `[1, 1]` is the top right. Using `axes` coordinates is useful when adding elements whose position does not depend on the data, e.g. placing a legend in the top right corner.
* Blended coordinates, obtained with `ax.get_xaxis_transform()`, mix the two systems: the x-axis is in data coordinates, and the y-axis is in axes coordinates. We use this system in the example below, where we add an arrow that is placed to the left of the bars, but spans a fixed fraction of the height regardless of how tall the bars are.
* `figure` coordinates (transform `fig.transFigure`, or `xycoords="figure fraction"` in `ax.annotate`) are given as a fraction of the figure size. The value `[0, 0]` is the bottom left, and the value `[1, 1]` is the top right. Using `figure` coordinates is useful when placing elements outside of the axes, especially with multiple subplots: for example, a single legend shared by two subplots and placed at the bottom center of the figure.
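A quick way to build intuition for these systems is to convert a few points to display (pixel) coordinates and compare them. The sketch below sets limits similar to those of our barplot, but it is a standalone example with arbitrary values (and the `Agg` backend), not the notebook's figure.

```python
import matplotlib as mpl

mpl.use("Agg")  # Non-interactive backend, so that the sketch also runs without a display;
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=1, ncols=1)
ax.set_xlim(-0.5, 5.5)
ax.set_ylim(0, 1.5)

# `axes` coordinates: (0, 0) is the bottom-left corner of the Axes;
bottom_left = ax.transAxes.transform((0, 0))
print(np.allclose(bottom_left, (ax.bbox.x0, ax.bbox.y0)))  # True

# `data` coordinates: (-0.5, 0) is the lower-left corner of the data limits, so it maps to the same pixel;
print(np.allclose(ax.transData.transform((-0.5, 0)), bottom_left))  # True

# Blended coordinates: x in data units, y as a fraction of the Axes height.
# (2, 0.5) sits at x=2 and halfway up the Axes, whatever the y-limits are;
blended = ax.get_xaxis_transform().transform((2, 0.5))
print(np.isclose(blended[0], ax.transData.transform((2, 0))[0]))  # True
print(np.isclose(blended[1], ax.bbox.y0 + 0.5 * ax.bbox.height))  # True

# `axes` coordinates are handy for items that do not depend on the data, like this corner label;
ax.text(0.98, 0.98, "top right", transform=ax.transAxes, ha="right", va="top")
```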

In [None]:
from matplotlib.patches import Rectangle


def higher_is_better_arrow(ax: Axes, linewidth: float = 1) -> Axes:
    # We need to create a bit of whitespace to the left of the plot, to add the arrow.
    # `ax.get_xlim()` returns a tuple with the left and right limits of the x-axis,
    # and we update it with `ax.set_xlim()`. The right limit is unchanged;
    ax.set_xlim(ax.get_xlim()[0] - 0.3, ax.get_xlim()[1])
    # To draw an arrow, we need to know the x and y coordinates of the start and end points.
    # The x coordinate is halfway between the left limit of the x-axis and the start of the first bar.
    # `ax.patches` returns a list of all the patches present in the plot.
    # In our case, all the patches are `Rectangles`, since we have a barplot;
    first_bar: Rectangle = ax.patches[0]
    x_coord: float = first_bar.get_x() - (first_bar.get_x() - ax.get_xlim()[0]) / 2
    # The y start coordinate is at 1% of the height, the end coordinate at 99% of the height.
    # By default, coordinates in Matplotlib are given in data coordinates.
    # To specify the y coordinates as a fraction of the height, we specify `xycoords=ax.get_xaxis_transform()`.
    # According to Matplotlib, "The x-direction is in data coordinates and the y-direction is in axis coordinates."
    # `xytext` is interpreted in the same system, since `textcoords` defaults to the value of `xycoords`.
    y_start: float = 0.01
    y_end: float = 0.99
    ax.annotate(
        "",
        xy=(x_coord, y_end),
        xytext=(x_coord, y_start),
        arrowprops=dict(
            arrowstyle="->",
            color="#2f2f2f",  # Slightly gray, looks better than pure black;
            linewidth=linewidth,
        ),
        xycoords=ax.get_xaxis_transform(),
    )
    return ax


# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)
ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)
# Update the existing legend labels to have more meaningful names;
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

# Update the x-axis and y-axis labels;
ax.set_xlabel("Dataset")
ax.set_ylabel("Value")

# Add the arrow to denote that higher values are better;
ax = higher_is_better_arrow(ax)

# Show the barplot;
plt.show()

## How big is the difference between bars?

Now we are communicating that our model is the best. Still, by how much is it the best? 5%? 10%? Hard to spot at a glance. Let's see how we can fix this.

### Adding a grid

The simplest improvement we can make is to add a grid to the y-axis. It will make it easier to connect the values on the y-axis to the height of the bars.

In [None]:
# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)

ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Update the existing legend labels to have more meaningful names;
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

# Update the x-axis and y-axis labels;
ax.set_xlabel("Dataset")
ax.set_ylabel("Value")

# Add the arrow to denote that higher values are better;
ax = higher_is_better_arrow(ax)

# Make sure that the grid is drawn below the bars, instead of above;
ax.set_axisbelow(True)
# Add a grid to the y-axis;
ax.grid(axis="y", linestyle="--", linewidth=0.5)

# Show the barplot;
plt.show()

### Adding labels to bar plots

While grids help identify the approximate magnitude of differences, we can be even more explicit, and directly add the values of each bar on top of it. To do so, we iterate over the bars, obtain their height, and add a textual label on top of the bar.

Once again, we create a separate function so that it will be easy to reuse. The vertical padding between the top of each bar and the label usually requires a bit of trial-and-error adjustment. The default value is sensible, but we leave it as an input parameter so that you can play with it.
The same applies to the font size of the label. If left `None`, Matplotlib will pick a default value. However, the default value might be too large or too small, so we might also need a bit of trial and error.

For completeness, Matplotlib offers an `ax.bar_label` function that does something similar ([link](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.bar_label.html)). It provides out-of-the-box support for error bars, but it makes it harder to perform advanced customizations, like plotting relative performance instead of absolute values.
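For comparison, here is a minimal sketch of `ax.bar_label` on a plain Matplotlib barplot, with hypothetical heights. It assumes Matplotlib 3.4 or newer, where `bar_label` was introduced; extra keyword arguments such as `fontsize` and `color` are forwarded to the underlying annotations.

```python
import matplotlib as mpl

mpl.use("Agg")  # Non-interactive backend, so that the sketch also runs without a display;
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=1, ncols=1)
# Hypothetical bar heights; `ax.bar` returns a `BarContainer` with one `Rectangle` per bar;
container = ax.bar(["Baseline", "State-of-the-art", "Our model"], [0.49, 0.72, 0.80])

# One call labels every bar in the container; `padding` is the distance from the bar, in points;
annotations = ax.bar_label(container, fmt="%.2f", padding=0.5, fontsize=6, color="#2f2f2f")
print([annotation.get_text() for annotation in annotations])  # ['0.49', '0.72', '0.80']
```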

In [None]:
from typing import Optional


def add_labels_to_bars(ax: Axes, vertical_offset_points: float = 0.5, font_size: Optional[int] = None) -> Axes:
    # Keep only patches that correspond to bars.
    # Not strictly necessary since in our case all patches are bars,
    # but we want to be safe;
    bars: list[Rectangle] = [p for p in ax.patches if isinstance(p, Rectangle)]
    for bar in bars:
        # The label is placed at the center x-coordinate of a bar,
        # using data coordinates;
        label_x_coordinate = bar.get_x() + bar.get_width() / 2
        # The label is placed at the top of a bar,
        # with a vertical offset whose size is expressed in points;
        height = bar.get_height()
        label_y_coordinate = height
        ax.annotate(
            text=f"{height:.2f}",  # Our label;
            xy=(label_x_coordinate, label_y_coordinate),  # Coordinates of the label, in data-coordinates;
            xytext=(0, vertical_offset_points),  # Coordinates of the text, as offset points w.r.t. `xy`;
            textcoords="offset points",
            ha="center",  # Horizontal alignment;
            va="bottom",  # Vertical alignment;
            color="#2f2f2f",  # Text color, slightly gray;
            fontsize=font_size,
        )
    return ax


# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)

ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Update the existing legend labels to have more meaningful names;
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

# Update the x-axis and y-axis labels;
ax.set_xlabel("Dataset")
ax.set_ylabel("Value")

# Add the arrow to denote that higher values are better;
ax = higher_is_better_arrow(ax)

# Make sure that the grid is drawn below the bars, instead of above;
ax.set_axisbelow(True)
# Add a grid to the y-axis;
ax.grid(axis="y", linestyle="--", linewidth=0.5)

# Add labels with absolute values on top of each bar;
ax = add_labels_to_bars(ax)

# Show the barplot;
plt.show()

See how the labels are too large and overlap with each other? 

This is a sad truth of plotting: **as much as you can automate, there will always be a lot of manual fine-tuning to get things just right 👌**.

We can fix the problem by either rotating the labels vertically (passing `rotation=90` to `ax.annotate` in `add_labels_to_bars`), or by using a smaller font size. Let's go for the second approach, since it's usually more readable. Here, a font size of `6` is a good compromise between readability and avoiding overlaps between labels. It took a couple of attempts to find the right value!

In [None]:
# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)

ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Update the existing legend labels to have more meaningful names;
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

# Update the x-axis and y-axis labels;
ax.set_xlabel("Dataset")
ax.set_ylabel("Value")

# Add the arrow to denote that higher values are better;
ax = higher_is_better_arrow(ax)

# Make sure that the grid is drawn below the bars, instead of above;
ax.set_axisbelow(True)
# Add a grid to the y-axis;
ax.grid(axis="y", linestyle="--", linewidth=0.5)

# Add labels with absolute values on top of each bar;
ax = add_labels_to_bars(ax, font_size=6)

# Show the barplot;
plt.show()

### Adding relative performance labels to our bars

Now we have labels with absolute values, and it's a big step forward. We know that our model has an average performance of `0.80`, which is better than the `0.49` of the baseline. By how much? `1.7X`? `1.8X`? Don't let your audience figure out the math, do it for them!

Instead of labels with absolute values, we can use labels with relative values, **measured against a baseline value**. In this case, we can use the values of the baseline model as the reference.
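The regrouping in the function below is the trickiest part, so here it is on plain lists with hypothetical heights. Each inner list plays the role of one `BarContainer` (one model), and `zip(*...)` transposes between "grouped by model" and "grouped by dataset".

```python
# Hypothetical bar heights, stored the way Matplotlib stores them:
# one list per model (one `BarContainer` each), one value per dataset;
heights_by_model = [
    [0.40, 0.50],  # Model A (baseline);
    [0.60, 0.80],  # Model B (state-of-the-art);
    [0.70, 0.90],  # Model C (our model);
]

# Transpose [model x dataset] into [dataset x model], as `zip(*ax.containers)` does;
heights_by_dataset = list(zip(*heights_by_model))
print(heights_by_dataset)  # [(0.4, 0.6, 0.7), (0.5, 0.8, 0.9)]

# Within each dataset, divide every bar by the lowest one;
relative_by_dataset = [[h / min(group) for h in group] for group in heights_by_dataset]

# Transpose back to [model x dataset] and flatten, to match the order of the bars;
relative = [value for group in zip(*relative_by_dataset) for value in group]

# Skip the reference bars (exactly 1.0), and format the others as +X%;
labels = [f"+{r - 1:.0%}" for r in relative if r != 1]
print(labels)  # ['+50%', '+60%', '+75%', '+80%']
```

As in the notebook's function, each value is divided by the lowest bar of its dataset, which matches the baseline only as long as the baseline is the worst model on every dataset.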

In [None]:
from matplotlib.container import BarContainer


def add_relative_performance_labels_to_bars(
    ax: Axes, vertical_offset_points: float = 0.5, font_size: Optional[float] = None
) -> Axes:
    # Computing relative performance is less trivial.
    # Bars are grouped in `BarContainer`, but each `BarContainer` contains
    # the bars for a given model, not for a single dataset!
    # So we need to re-group the bars so that they are grouped by dataset,
    # obtain the height of the minimum bar (assumed to be the baseline), and compute the relative performance;
    containers: list[BarContainer] = ax.containers
    # From [container x bars] to [bars x container];
    bars_grouped_by_dataset: list[list[Rectangle]] = list(zip(*containers))
    relative_heights_by_group: list[list[float]] = []
    for group in bars_grouped_by_dataset:
        # Bar with the lowest value for each dataset;
        min_height = min([bar.get_height() for bar in group])
        # Compute the relative performance for each bar in the group;
        relative_heights_by_group.append([bar.get_height() / min_height for bar in group])
    # From [bars x container] to [container x bars], then flatten the list of values;
    relative_heights: list[float] = [value for group in zip(*relative_heights_by_group) for value in group]

    # Now we associate each bar to its corresponding label;
    bars: list[Rectangle] = [p for p in ax.patches if isinstance(p, Rectangle)]
    for bar, relative_height in zip(bars, relative_heights):
        # If the relative_height is 1 (i.e. the baseline), we can skip it;
        if relative_height == 1:
            continue
        label_x_coordinate = bar.get_x() + bar.get_width() / 2
        # We still need the absolute height of the bar, to place the label;
        label_y_coordinate = bar.get_height()
        # But the label is the relative performance, as +X%;
        label = f"+{relative_height - 1:.0%}"
        ax.annotate(
            text=label,  # Our label;
            xy=(label_x_coordinate, label_y_coordinate),  # Coordinates of the label, in data-coordinates;
            xytext=(0, vertical_offset_points),  # Coordinates of the text, as offset points w.r.t. `xy`;
            textcoords="offset points",
            ha="center",  # Horizontal alignment;
            va="bottom",  # Vertical alignment;
            color="#2f2f2f",  # Text color, slightly gray;
            fontsize=font_size,
        )
    return ax


# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)

ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", ax=ax)

# Update the existing legend labels to have more meaningful names;
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

# Update the x-axis and y-axis labels;
ax.set_xlabel("Dataset")
ax.set_ylabel("Value")

# Add the arrow to denote that higher values are better;
ax = higher_is_better_arrow(ax)

# Make sure that the grid is drawn below the bars, instead of above;
ax.set_axisbelow(True)
# Add a grid to the y-axis;
ax.grid(axis="y", linestyle="--", linewidth=0.5)

# Add labels with relative performance values on top of each bar.
# Guess what? The labels are longer, so we need a smaller font size!
ax = add_relative_performance_labels_to_bars(ax, font_size=5.5)

# Show the barplot;
plt.show()


This second approach is not necessarily better than the first! For example, if the goal is to show the improvement of our model versus the state-of-the-art model, we would not get an immediate KPI for the quality improvement, since both labels (`+49%`, `+64%`) are measured relative to the baseline bar.

Instead, one might want to make clear that our model is `+11%` better than the state-of-the-art. You could do so by using the **state-of-the-art values as the reference** for the relative performance labels, or by adding labels **for both absolute and relative performance**. Try different strategies and see which best suits your data and message.
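A minimal sketch of the "choose your reference" idea, assuming the bar heights of one group are available as a plain list (e.g. collected with `bar.get_height()`). `relative_labels` is a hypothetical helper, not part of `segretini-matplottini`, and the heights are placeholder values, not the notebook's data.

```python
def relative_labels(heights: list[float], reference_index: int) -> list[str]:
    """Format each height as a signed percentage relative to heights[reference_index].

    The reference bar itself gets an empty label, like the notebook skips the baseline bar.
    """
    reference = heights[reference_index]
    return [
        "" if i == reference_index else f"{height / reference - 1:+.0%}"
        for i, height in enumerate(heights)
    ]


# Placeholder heights for [baseline, state-of-the-art, our model];
group = [2.0, 3.0, 3.6]
print(relative_labels(group, reference_index=0))  # ['', '+50%', '+80%']
print(relative_labels(group, reference_index=1))  # ['-33%', '', '+20%']
```

Using the `+` format flag keeps the sign explicit, so bars that are worse than the chosen reference get a readable negative label instead of a misleading `+`.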

## Choosing the right colors

Our current plot is reasonably good. It clearly communicates a message (that our model is the best, and it's 64% better than a baseline). But we can still do more. An important aspect is to choose meaningful colors. There is a strong degree of subjectivity here, but there are some general guidelines that can help us.

### Not everyone sees colors in the same way

A lot of people are color-blind, and there are many types of color blindness! In general, make sure that your colors are distinguishable by color-blind people.
Also, if you are planning to add your plot to a paper, remember that many researchers print papers in black and white. Always check that your colors are distinguishable in black and white! 
If colors are too similar, try adjusting the **L** (lightness) in the **HSL** representation, or the **B** (brightness) in the **HSB** representation.
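As a sketch of the lightness tweak, Python's standard `colorsys` module works in HLS, which is the same model as HSL with the channels in a different order. `adjust_lightness` is a hypothetical helper: it scales only the lightness channel and leaves hue and saturation untouched.

```python
import colorsys

from matplotlib.colors import to_hex, to_rgb


def adjust_lightness(color: str, factor: float) -> str:
    """Scale the HLS lightness of `color` by `factor`, clipping the result to [0, 1]."""
    red, green, blue = to_rgb(color)
    hue, lightness, saturation = colorsys.rgb_to_hls(red, green, blue)
    lightness = min(1.0, max(0.0, lightness * factor))
    return to_hex(colorsys.hls_to_rgb(hue, lightness, saturation))


# Make a color 20% lighter and 20% darker, e.g. to separate two bars that look alike in B&W;
base = "#81C798"
print(adjust_lightness(base, 1.2), adjust_lightness(base, 0.8))
```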

### Hidden color biases

People tend to associate implicit meanings with colors. To simplify the matter as much as possible, green is often associated with positive things, while red is bad.
Also, colors with higher saturation, and colors with very low or high brightness, usually catch the attention first.

### Crafting a color palette

How do we choose a palette? Here are some utilities made available by Seaborn and `segretini-matplottini`.

* Seaborn has a [very thorough article](https://seaborn.pydata.org/tutorial/color_palettes.html) on choosing color palettes. It's a great reference to know more about the topic. It offers many pre-defined palettes, and most of them will satisfy the criteria above.
* Seaborn `palplot` is a simple function that plots a list of colors. It's great to quickly iterate between different palettes, until you find something that you like.
* `segretini_matplottini.utils.convert_colors_to_grayscale` converts a list of colors to grayscale. You can combine it with `palplot` to ensure that your palette will work well when printed in black and white.
* `segretini_matplottini.utils.create_hex_palette` creates a palette with the specified number of colors, interpolating between a start and end color.

Finally, [Adobe Color](https://color.adobe.com/it/create/color-wheel) is a free tool built to create palettes. It provides many pre-made palettes, and accessibility tools to test for color blindness safety.
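If you want a quick numeric check in addition to eyeballing the grayscale palette, a small sketch is to compute an approximate gray level for each color and look at the smallest gap between any two of them. This uses the ITU-R BT.601 luma weights, a common RGB-to-grayscale conversion; we do not assume that `convert_colors_to_grayscale` uses the same formula. `grayscale_values` and `min_grayscale_gap` are hypothetical helpers, and the example palette is the one picked later in this notebook.

```python
from itertools import combinations

from matplotlib.colors import to_rgb


def grayscale_values(palette: list[str]) -> list[float]:
    """Approximate gray level (0 = black, 1 = white) of each color, using BT.601 luma weights."""
    return [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in map(to_rgb, palette)]


def min_grayscale_gap(palette: list[str]) -> float:
    """Smallest gray-level difference between any two colors of the palette."""
    values = grayscale_values(palette)
    return min(abs(a - b) for a, b in combinations(values, 2))


palette = ["#E8CFB5", "#81C798", "#358787"]
print([round(value, 2) for value in grayscale_values(palette)])
print(round(min_grayscale_gap(palette), 2))
```

What counts as "distinguishable" is a judgment call: treat a threshold such as `0.1` as an arbitrary starting point, and always confirm by looking at the palette.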

In [None]:
from segretini_matplottini.utils import convert_colors_to_grayscale, convert_colors_to_deficiency

# Plot some pre-defined color palettes in Seaborn.
# These are continuous palettes, and we can specify how many colors we want;
number_of_colors = 16
palette_names = ["flare", "viridis", "rocket", "vlag"]
for palette_name in palette_names:
    # Create the palette as a list of RGB triplets;
    palette = sns.color_palette(palette_name, n_colors=number_of_colors)
    # Plot the palette;
    sns.palplot(palette)
    # Obtain the current plot, add the name of the palette to it;
    ax: Axes = plt.gca()
    ax.set_title(f"Palette '{palette_name}' with {number_of_colors} colors")

# Since we have three bars per group, let's see how they look with just 3 colors;
number_of_colors = 3
for palette_name in palette_names:
    palette = sns.color_palette(palette_name, n_colors=number_of_colors)
    sns.palplot(palette)
    ax: Axes = plt.gca()
    ax.set_title(f"'{palette_name}', {number_of_colors} colors")

# Let's see how they look in black and white.
# See how `viridis` and `vlag` have two almost identical colors, and would not work well in B&W.
number_of_colors = 3
for palette_name in palette_names:
    palette = sns.color_palette(palette_name, n_colors=number_of_colors)
    palette = convert_colors_to_grayscale(palette)
    sns.palplot(palette)
    ax: Axes = plt.gca()
    ax.set_title(f"'{palette_name}', {number_of_colors} colors, in B&W")

# See how they look when seen by the eyes of a person affected by protanomaly, a type of color deficiency;
number_of_colors = 3
for palette_name in palette_names:
    palette = sns.color_palette(palette_name, n_colors=number_of_colors)
    palette = convert_colors_to_deficiency(palette, deficiency="protanomaly")
    sns.palplot(palette)
    ax: Axes = plt.gca()
    ax.set_title(f"'{palette_name}', {number_of_colors} colors, protanomaly")

For our barplot, we can select colors by hand, since we need just three of them.
Since the first bar in each group is the baseline, we can give it a color that's mild but also different from the other bars, so that it doesn't draw too much attention.
On the other hand, we want to give our model the most visible and striking color, so that it stands out from the other bars.

The following colors are taken from the `PALETTE_ORANGE_BASELINE_AND_GREEN_TONES` palette in the `segretini_matplottini.utils.colors` module, which provides a few palettes and colors that are ready to use. The palette looks good, and it works well in black and white too.

In [None]:
palette = ["#E8CFB5", "#81C798", "#358787"]
# Original palette;
sns.palplot(palette)
# The same palette, converted to grayscale;
palette_bw = convert_colors_to_grayscale(palette)
sns.palplot(palette_bw)
# The same palette, as seen with three different types of color deficiency;
palette_protanomaly = convert_colors_to_deficiency(palette, "protanomaly", 1)
sns.palplot(palette_protanomaly)
palette_deuteranomaly = convert_colors_to_deficiency(palette, "deuteranomaly", 1)
sns.palplot(palette_deuteranomaly)
palette_tritanomaly = convert_colors_to_deficiency(palette, "tritanomaly", 1)
sns.palplot(palette_tritanomaly)

Let's apply the new palette to our barplot, using the `palette` parameter of `barplot`. The number of colors has to match the number of hue levels, i.e. the number of bars per group.
Our barplot will immediately look much better. 🎨

> 💡 Seaborn desaturates the colors of the palette that we specify: the default `saturation` of `barplot` is `0.75`. To make the colors pop a bit more, we set a saturation of `0.9`.
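To preview what different saturation levels do to our colors before plotting, we can use `seaborn.desaturate`, which scales the saturation channel of a color in HLS space. We assume this matches what `barplot` does internally with its `saturation` argument; if exact colors matter, double-check against your seaborn version. `hls_saturation` is a small hypothetical helper used to inspect the result.

```python
import colorsys

import seaborn as sns
from matplotlib.colors import to_hex, to_rgb


def hls_saturation(color) -> float:
    """Saturation channel of a color in the HLS model."""
    return colorsys.rgb_to_hls(*to_rgb(color))[2]


palette = ["#E8CFB5", "#81C798", "#358787"]
for color in palette:
    # `sns.desaturate` multiplies the HLS saturation by the given proportion;
    previews = {prop: to_hex(sns.desaturate(color, prop)) for prop in (0.75, 0.9, 1.0)}
    print(color, previews)
```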

While we are at it, we can create a function to plot the current barplot, to avoid repeating the same code over and over.

In [None]:
def barplot(data: pd.DataFrame, palette: list[str], ax: Axes) -> Axes:
    ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", palette=palette, ax=ax, saturation=0.9)

    # Update the legend labels;
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=["Baseline", "State-of-the-art", "Our model"])

    # Update the x-axis and y-axis labels;
    ax.set_xlabel("Dataset")
    ax.set_ylabel("Value")

    # Add the arrow to denote that higher values are better;
    ax = higher_is_better_arrow(ax)

    # Make sure that the grid is drawn below the bars, instead of above;
    ax.set_axisbelow(True)
    # Add a grid to the y-axis;
    ax.grid(axis="y", linestyle="--", linewidth=0.5)

    # Add labels with relative performance values on top of each bar.
    ax = add_relative_performance_labels_to_bars(ax, font_size=5.5)

    return ax


# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)

# Main plot;
ax = barplot(data=data, palette=palette, ax=ax)

# Show the barplot;
plt.show()


## Adjusting the order of bar groups

You might have noticed how the order of the bar groups in our current plot does not make a lot of sense, with `dataset_10` appearing before `dataset_1`. We can fix this as follows.
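The sort key in the next cell assumes that every dataset name looks like `<prefix>_<integer>`: a name without an underscore, or with a non-numeric suffix, would make `astype(int)` fail. If your names are less regular, a slightly more defensive sketch extracts the trailing digits with a regular expression and sends names without a number to the end. `sort_by_numeric_suffix` is a hypothetical helper, and the names below are placeholders.

```python
import pandas as pd


def sort_by_numeric_suffix(data: pd.DataFrame, column: str = "dataset") -> pd.DataFrame:
    """Sort rows by the integer at the end of `column`; names without one go last."""
    # Extract the trailing digits, e.g. "dataset_10" -> 10.0; missing suffixes become NaN;
    suffix = data[column].str.extract(r"(\d+)$", expand=False).astype(float)
    # Sort by the numeric suffix, then by the full name to break ties deterministically;
    ordered = data.assign(_suffix=suffix).sort_values(by=["_suffix", column], na_position="last")
    return ordered.drop(columns="_suffix")


# Placeholder names, not the notebook's data;
names = pd.DataFrame({"dataset": ["dataset_10", "dataset_2", "dataset_1", "summary"]})
print(sort_by_numeric_suffix(names)["dataset"].tolist())
```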

In [None]:
# Put the data loading logic into a function;
def load_data() -> pd.DataFrame:
    data = pd.read_csv("../data/notebooks/1_getting_started_with_barplots/barplot_data.csv")

    # Order the dataset names by their suffix, which is treated as integer;
    data = data.sort_values(by="dataset", key=lambda x: x.str.split("_").str[1].astype(int))

    # Add the average results;
    average_results: pd.DataFrame = data.groupby("model").mean(numeric_only=True)
    average_results = average_results.reset_index()
    average_results["dataset"] = "Average"
    data = pd.concat([average_results, data])
    return data


# Load the data
data: pd.DataFrame = load_data()

# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)

# Main plot;
ax = barplot(data=data, palette=palette, ax=ax)

# Show the barplot;
plt.show()


It worked! 🧑‍🔬 Another alternative, which works for any arbitrary order (not just the order obtained from the suffixes), is to turn the `dataset` column into a **categorical** variable, and plot the bar groups by preserving the order of the categories. For example, we could order the bar groups by their baseline value, from low to high.
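A minimal pandas sketch of the idea, with hypothetical placeholder values: once a column is categorical, operations that respect categories (sorting here, and the order of the bar groups in seaborn) follow the order of the categories instead of the alphabetical one. `order_by_value` is a hypothetical helper.

```python
import pandas as pd


def order_by_value(data: pd.DataFrame, label: str, value: str) -> pd.DataFrame:
    """Turn `label` into a categorical column whose categories follow `value`, low to high."""
    order = data.sort_values(by=value)[label].unique().tolist()
    data = data.copy()
    data[label] = pd.Categorical(data[label], categories=order)
    return data


# Placeholder baseline values, not the notebook's data;
df = pd.DataFrame({"dataset": ["dataset_1", "dataset_2", "dataset_3"], "value": [0.9, 0.5, 0.7]})
df = order_by_value(df, label="dataset", value="value")
print(df["dataset"].cat.categories.tolist())
print(df.sort_values(by="dataset")["dataset"].tolist())
```

Note that the row order of the DataFrame is unchanged: only the category order carries the information, which is why the "Average" category can simply be prepended to the list of categories.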

Note that by sorting the labels in this way, the legend would overlap with the last bar group. As such, we move it by hand so that it is in a better position.

While we are at it, let's also add custom labels for the x-axis ticks.

In [None]:
from matplotlib.text import Text


# Put the data loading logic into a function;
def load_data() -> pd.DataFrame:
    data = pd.read_csv("../data/notebooks/1_getting_started_with_barplots/barplot_data.csv")

    # Obtain the datasets sorted by baseline value, from low to high;
    sorted_categories = data[data["model"] == "A"].sort_values(by="value")["dataset"].unique().tolist()

    # Add the average results;
    average_results: pd.DataFrame = data.groupby("model").mean(numeric_only=True)
    average_results = average_results.reset_index()
    average_results["dataset"] = "Average"
    data = pd.concat([average_results, data])

    # Create categorical variables for the dataset column.
    # We do this transformation at the end, so that we have already added the "Average" category.
    # Let's ensure that the "Average" is still the first category;
    data["dataset"] = pd.Categorical(data["dataset"], categories=["Average"] + sorted_categories)

    return data


# Load the data
data: pd.DataFrame = load_data()

# Plot the results, like we did before;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1)

# Main plot;
ax = barplot(data=data, palette=palette, ax=ax)

# Move the legend to the top left, so it does not overlap with the last bar group.
# Since the top left has the arrow, we manually
# specify the position of the legend to avoid overlapping with the arrow.
# With a bit of trial and error, we locate it at 5% of the x-axis width, and 100% of the y-axis height.
sns.move_legend(ax, loc="upper left", bbox_to_anchor=(0.05, 1))

# Update the x-axis labels;
x_tick_labels: list[Text] = ax.get_xticklabels()
ax.set_xticklabels([x.get_text().replace("_", " ").title() for x in x_tick_labels])

# Show the barplot;
plt.show()


## Adjusting the plot size

In most situations where you create a plot, **you are bound by space constraints**, whether we are talking about a research paper, a poster, or a slide deck. You need to make efficient use of this space, avoiding unnecessary whitespace, while ensuring that everything is readable.

Adjusting the plot size and its elements to make efficient use of the space is a critical step. When creating a plot from scratch, **start with a reasonable figure size**, then adjust the size of the elements as needed. This is faster than creating a beautiful plot without caring about its size, only to realize that you need to redo everything because you cannot shrink the figure to fit the space you are given without making it unreadable.

### 📐 Figure size, DPI, and PPI

When talking about plot sizes, the following concepts are relevant.

* The **figure size**, expressed in inches, is how much space is given to axes and other elements. If you create a figure whose size is `2 x 2` inches, it will be `2 x 2` inches when you save it. The default value in Matplotlib is `6.4 x 4.8` inches.
* The **DPI** (dots per inch) says how many pixels are in an inch. If you create a figure whose size is `2 x 2` inches, and the DPI is `100`, the figure will be `200 x 200` pixels when you save it. **The default DPI is 100**. To get a crisp-looking image, it is better to use a higher DPI, such as `300` or `600`.
* Some elements, such as text and lines, have their size expressed in points. The **PPI** (points per inch) in Matplotlib is `72` ([source](https://matplotlib.org/stable/tutorials/advanced/transforms_tutorial.html#using-offset-transforms-to-create-a-shadow-effect)). If you increase the figure size, you will also have to increase font sizes and line widths to keep the same relative size. If you change the DPI, you only change the resolution: the relative size of elements is unchanged.
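The arithmetic above can be checked directly in Matplotlib. This sketch uses the non-interactive `Agg` backend so that it also runs outside a notebook; `figure_pixels` and `points_to_pixels` are hypothetical helpers.

```python
import matplotlib

matplotlib.use("Agg")  # Non-interactive backend, so the sketch also runs outside a notebook;
import matplotlib.pyplot as plt

POINTS_PER_INCH = 72  # Size of a point in Matplotlib;


def figure_pixels(width_in: float, height_in: float, dpi: float) -> tuple[int, int]:
    """Pixel size of a figure with the given size (in inches) and DPI."""
    fig = plt.figure(figsize=(width_in, height_in), dpi=dpi)
    width_px, height_px = fig.get_size_inches() * fig.dpi
    plt.close(fig)
    return int(round(width_px)), int(round(height_px))


def points_to_pixels(points: float, dpi: float) -> float:
    """How many pixels a length expressed in points spans at the given DPI."""
    return points / POINTS_PER_INCH * dpi


print(figure_pixels(2, 2, dpi=100))  # Same size in inches...
print(figure_pixels(2, 2, dpi=300))  # ...but three times as many pixels per side;
# An 8pt font is always 1/9 of an inch tall, whatever the DPI: only the pixel count changes;
print(points_to_pixels(8, dpi=100), points_to_pixels(8, dpi=300))
```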

Matplotlib's default values (figure size `6.4 x 4.8` inches, DPI `100`, PPI `72`) are good for plots that are read on a screen. But if you plan to use your plot in a paper, a poster, or something that is not a screen, you will need to adjust the figure size and DPI.

Seaborn provides a `set_context` function ([link](https://seaborn.pydata.org/generated/seaborn.set_context.html)) that has some rules to scale the elements for contexts such as `notebook`, `paper`, `poster`. These are better than the default Matplotlib settings, but might not be enough for more advanced needs.

For reference, LaTeX measures lengths in TeX points, with 72.27 points per inch, very close to Matplotlib's 72. A two-column paper template has a column width of around 252 points, or about 3.5 inches. A one-column template has a text width of around 395 points, or about 5.5 inches. The font size of a typical research paper is between 8 and 12 points, with 10 being the most common. These are also reasonable sizes for textual elements in your plots.
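To get the exact width of your own template, LaTeX can print it with `\the\columnwidth` (or `\the\textwidth`), expressed in TeX points. A small conversion sketch, assuming the widths quoted above; `tex_points_to_inches` is a hypothetical helper.

```python
TEX_POINTS_PER_INCH = 72.27  # TeX "pt" unit;
MATPLOTLIB_POINTS_PER_INCH = 72.0  # PostScript/PDF "bp" unit, used by Matplotlib;


def tex_points_to_inches(points: float) -> float:
    """Convert a length in TeX points to inches, e.g. to use it as a Matplotlib figure width."""
    return points / TEX_POINTS_PER_INCH


for name, width_pt in [("two-column width", 252), ("one-column width", 395)]:
    inches = tex_points_to_inches(width_pt)
    # The same length in Matplotlib points, useful to reason about text size relative to the column;
    matplotlib_points = inches * MATPLOTLIB_POINTS_PER_INCH
    print(f"{name}: {inches:.2f} in, {matplotlib_points:.1f} Matplotlib points")
```

The difference between the two point definitions is below 0.5%, so for figure sizes it is usually negligible.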

`segretini-matplottini` uses `3.5 inches` as the default figure size for most plots, with a font size of `8`. Even if you are not creating a plot for a research paper, that's a reasonable figure size that ensures good readability.

### 🏗️ Blueprint for adjusting the plot size

If you are creating a plot from scratch, this blueprint works well.
1. Start with a figure width that is reasonable for your use case. For example, `3.5 inches` for a 2-column LaTeX paper.
2. Start with the same width and height (e.g. `3.5 inches`), unless you already know that a different height works better. For example, about `1.97 inches` (`3.5 * 9 / 16`) if you are creating a plot for a `16:9` slide deck.
3. Set a font size of `8`.
4. **Create your plot.**
5. **Adjust the left and right margins** with `plt.subplots_adjust()`. If you have a grid of plots, also adjust `wspace` (the horizontal space between subplots).
6. **Optimize the vertical space.** Usually, that means reducing the height as much as possible, to leave room for other text or plots. Start by reducing the figure height, then adjust the top and bottom margins with `plt.subplots_adjust()`. If you have a grid of plots, also adjust `hspace` (the vertical space between subplots).
7. If necessary, you can reduce the font size of elements such as tick labels. If you are creating a plot for a research paper, usually a font size as small as `5` or `6` will still be readable.
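The steps of the blueprint can be sketched as follows. `blueprint_figure` is a hypothetical helper, the bar values are placeholders, and the margins are the ones used later in this notebook for a figure of the same size; in practice you would tune them by trial and error.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def blueprint_figure(width: float = 3.5, aspect: float = 1.0, font_size: float = 8):
    """Steps 1-3 of the blueprint: fixed width, height = width * aspect, base font size.

    Note: this updates the global rcParams, so later plots also use `font_size`.
    """
    plt.rcParams["font.size"] = font_size
    fig, ax = plt.subplots(figsize=(width, width * aspect))
    return fig, ax


# Steps 1-3: a 16:9 figure for a slide deck, with 8pt text;
fig, ax = blueprint_figure(aspect=9 / 16)
# Step 4: create the plot (placeholder values);
ax.bar([0, 1, 2], [1.0, 1.5, 1.2])
ax.set_ylabel("Value")
# Steps 5-6: tune the margins (and the height, if needed);
fig.subplots_adjust(left=0.11, right=0.99, bottom=0.17, top=0.98)
# Step 7: shrink tick labels if they are still too large;
ax.tick_params(axis="both", labelsize=6)
```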

### 💡 Advanced tips for plotting in notebooks

From this point onward, all cells that contain plots are affected by `set_matplotlib_formats("retina", bbox_inches=None)`. 

The `retina` setting increases the resolution of the plots rendered within the notebook: it doubles the DPI used to render the image, while keeping the displayed width and height of the plot unchanged (technically, the image is rendered at twice the resolution, and then displayed at half its pixel size).

Instead, `bbox_inches=None` **displays plots in the notebook exactly as they would be saved to a file.**
Without it, the notebook renders plots with `bbox_inches="tight"`, which crops (or expands) the image to fit its visible elements before displaying it. While this gives a prettier result in the notebook, it is inconsistent with what we would save to a file with the default settings. It also prevents fine-tuning the margins precisely, since the margins we set are replaced by the automatic cropping.

> ⚠️ Changes made with `set_matplotlib_formats` are not reset by `plt.rcdefaults()`. If you need to restore the default settings, use `set_matplotlib_formats("png", bbox_inches="tight")`.
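We can reproduce the difference outside the notebook by saving the same figure with and without `bbox_inches="tight"` and comparing the pixel sizes of the resulting PNGs. A sketch: `png_size` is a hypothetical helper that reads the width and height from the PNG header, and the plotted line is a placeholder.

```python
import io
import struct

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def png_size(fig, **savefig_kwargs) -> tuple[int, int]:
    """Save `fig` as PNG in memory and read its width and height from the PNG header."""
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", **savefig_kwargs)
    # PNG layout: 8-byte signature, 4-byte chunk length, b"IHDR", then width and height (big-endian);
    return struct.unpack(">II", buffer.getvalue()[16:24])


fig, ax = plt.subplots(figsize=(6.4, 4.8), dpi=100)
ax.plot([0, 1], [0, 1])  # Placeholder data;
full_size = png_size(fig, dpi=100)  # Exactly figsize * dpi;
tight_size = png_size(fig, dpi=100, bbox_inches="tight")  # Cropped around the visible elements;
plt.close(fig)
print(full_size, tight_size)
```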

In [None]:
# Increase quality of plots in the notebook,
# and prevent automatic adjustments to the margins;
from matplotlib_inline.backend_inline import set_matplotlib_formats

set_matplotlib_formats("retina", bbox_inches=None)

# Load the data
data: pd.DataFrame = load_data()

# Set the width to 3.5 inches. The height gives a 16:9 aspect ratio,
# since we know that the plot benefits from a wider layout;
figure_size: tuple[float, float] = (3.5, 3.5 * 9 / 16)

# Create a figure with the specified size;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figure_size)

# Main plot;
ax = barplot(data=data, palette=palette, ax=ax)

# Move the legend to the top left;
sns.move_legend(ax, loc="upper left", bbox_to_anchor=(0.05, 1))

# Update the x-axis labels;
x_tick_labels: list[Text] = ax.get_xticklabels()
ax.set_xticklabels([x.get_text().replace("_", " ").title() for x in x_tick_labels])

# Show the barplot;
plt.show()

The new figure is tiny! That's because it is now only 3.5 inches wide, much smaller than Matplotlib's default figure size of **6.4 inches for the width and 4.8 inches for the height**, while the **DPI is still 100**.
If we want to keep the same figure size but display the plot larger in the notebook, we can **increase the DPI to 200**.

In [None]:
# Create a figure with the specified size;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figure_size, dpi=200)

# Main plot;
ax = barplot(data=data, palette=palette, ax=ax)

# Show the barplot;
plt.show()


The result has a reasonable size, but now elements overlap and the plot is unreadable. Compared to the original figure (before changing the figure size), there is less wasted space, but the text is now too large. Let's start by setting the **font size to 8**.
We will also need to adjust the margins, since the x-axis and y-axis labels are cut off.

> 💡 Now, we draw the legend directly in the right location, instead of moving it later. The reason is that `sns.move_legend` does not move the legend, but redraws it from scratch, and the redrawn legend ignores the font size that we specified. If you need to move the legend without redrawing it, you can use `ax.get_legend().set_bbox_to_anchor((0.05, 1.0))`. Unlike `sns.move_legend`, this does not change the legend's `loc` (e.g. `upper left`), which decides which corner of the legend is placed at the anchor point, so you need to tune the coordinates to get the placement right.
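A minimal sketch of moving a legend without recreating it, on a plain Matplotlib barplot with placeholder data: the legend object, and therefore its font size, stays the same, and only its anchor point changes.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(3.5, 2.0))
ax.bar([0, 1], [1.0, 2.0], label="Placeholder")  # Placeholder data;
legend = ax.legend(fontsize=8, loc="upper left", bbox_to_anchor=(0.05, 1.0))

# Shift the anchor point to the right, in axes coordinates. `loc` stays "upper left",
# so the upper-left corner of the legend is placed at the new anchor point;
legend.set_bbox_to_anchor((0.10, 1.0))

# Same legend object, same font size;
print(ax.get_legend() is legend, legend.get_texts()[0].get_fontsize())
```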

In [None]:
# Update the standard plotting function to take the font size as input parameter,
# and also put the other changes we did (legend and x-axis labels) in there;
def barplot(data: pd.DataFrame, palette: list[str], font_size: int, ax: Axes) -> Axes:
    ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", palette=palette, ax=ax, saturation=0.9)

    # Update the legend labels;
    handles, _ = ax.get_legend_handles_labels()
    # Place the legend directly in the right position, instead of moving it later;
    ax.legend(
        handles=handles,
        labels=["Baseline", "State-of-the-art", "Our model"],
        fontsize=font_size,
        loc="upper left",
        bbox_to_anchor=(0.05, 1.0),
    )

    # Update the x-axis and y-axis labels;
    ax.set_xlabel("Dataset", fontsize=font_size)
    ax.set_ylabel("Value", fontsize=font_size)

    # Add the arrow to denote that higher values are better;
    ax = higher_is_better_arrow(ax)

    # Make sure that the grid is drawn below the bars, instead of above;
    ax.set_axisbelow(True)
    # Add a grid to the y-axis;
    ax.grid(axis="y", linestyle="--", linewidth=0.5)

    # The font size of the bar labels is also parametrized by the input font size
    ax = add_relative_performance_labels_to_bars(ax, font_size=font_size * 0.4)

    # Update the x-axis labels;
    x_tick_labels: list[Text] = ax.get_xticklabels()
    ax.set_xticklabels([x.get_text().replace("_", " ").title() for x in x_tick_labels])

    # Set the font size of the x-axis and y-axis tick labels;
    ax.tick_params(axis="both", labelsize=font_size)

    return ax


# Create a figure with the specified size;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figure_size, dpi=200)

# Main plot;
ax = barplot(data=data, palette=palette, font_size=8, ax=ax)

# Show the barplot;
plt.show()


We are getting there! 💪 As the next steps, let's reduce the font size of the x-axis and y-axis tick labels, and tune the margins to reduce unnecessary white space around the plot. 

Tuning the margins takes a lot of trial and error. Here we'll directly use good values, but in practice, it would take a few attempts before settling on something that looks good.
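One way to make this trial and error less fiddly, sketched below under our own assumptions, is to express the margins in inches instead of figure fractions. Since labels and tick labels are sized in points, a margin in inches stays correct if you later change the figure height, while a fraction would have to be recomputed. `subplots_adjust_inches` is a hypothetical helper, and the margin values are examples.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def subplots_adjust_inches(fig, left: float, right: float, bottom: float, top: float) -> None:
    """Set the figure margins in inches, converting them to the fractions that Matplotlib expects."""
    width, height = fig.get_size_inches()
    fig.subplots_adjust(
        left=left / width,
        right=1 - right / width,
        bottom=bottom / height,
        top=1 - top / height,
    )


fig, ax = plt.subplots(figsize=(3.5, 2.0))
# E.g. 0.35 in on the left for the y-axis label, 0.34 in at the bottom for the x-axis label;
subplots_adjust_inches(fig, left=0.35, right=0.035, bottom=0.34, top=0.04)
print(fig.subplotpars.left, fig.subplotpars.right, fig.subplotpars.bottom, fig.subplotpars.top)
```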

In [None]:
def barplot(data: pd.DataFrame, palette: list[str], font_size: int, ax: Axes) -> Axes:
    ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", palette=palette, ax=ax, saturation=0.9)

    # Update the legend labels;
    handles, _ = ax.get_legend_handles_labels()
    # Place the legend directly in the right position, instead of moving it later;
    ax.legend(
        handles=handles,
        labels=["Baseline", "State-of-the-art", "Our model"],
        fontsize=font_size,
        loc="upper left",
        bbox_to_anchor=(0.05, 1.0),
    )

    # Update the x-axis and y-axis labels;
    ax.set_xlabel("Dataset", fontsize=font_size)
    # Here we also reduce the `labelpad`, the space between the y-axis label and the y-axis ticks;
    ax.set_ylabel("Value", fontsize=font_size, labelpad=1)

    # Add the arrow to denote that higher values are better;
    ax = higher_is_better_arrow(ax)

    # Make sure that the grid is drawn below the bars, instead of above;
    ax.set_axisbelow(True)
    # Add a grid to the y-axis;
    ax.grid(axis="y", linestyle="--", linewidth=0.5)

    # The font size of the bar labels is also parametrized by the input font size
    ax = add_relative_performance_labels_to_bars(ax, font_size=font_size * 0.4)

    # Update the x-axis labels;
    x_tick_labels: list[Text] = ax.get_xticklabels()
    ax.set_xticklabels([x.get_text().replace("_", " ").title() for x in x_tick_labels])

    # Set the font size of the x-axis and y-axis tick labels;
    ax.tick_params(axis="both", labelsize=font_size * 0.7)

    return ax


# Create a figure with the specified size;
plt.rcdefaults()
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figure_size, dpi=200)

# Adjust the margins of the figure;
fig.subplots_adjust(left=0.11, right=0.99, bottom=0.17, top=0.98)

# Main plot;
ax = barplot(data=data, palette=palette, font_size=8, ax=ax)

# Show the barplot;
plt.show()

## 💾 Saving your plots

We are finally at a point where our plot looks great, and we are ready to save it. To do so, you can use `fig.savefig(...)` and specify the path to the file where the plot is saved. Keep the following in mind.
* If your goal is to put your plot in a LaTeX document, save the plot as a PDF to get a lossless vector format. You can also use SVG, for example, if you need to show the plot on a web page.
* If you prefer a raster format, use PNG (e.g. for slide decks). In this case, do not forget to specify the DPI of the output file: the `dpi` passed to `savefig` overrides the DPI that you specify in `plt.subplots(...)`, which is used only if you do not pass one. Usually, a DPI of 300 works well. If you have small textual elements, increase the DPI as much as you need.
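A quick way to verify which DPI ends up in the saved file is to save a figure to memory and read the PNG size. A sketch with placeholder data; `png_size` is a hypothetical helper that reads the width and height from the PNG header.

```python
import io
import struct

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def png_size(fig, **savefig_kwargs) -> tuple[int, int]:
    """Save `fig` as PNG in memory and read its width and height from the PNG header."""
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", **savefig_kwargs)
    return struct.unpack(">II", buffer.getvalue()[16:24])


fig, ax = plt.subplots(figsize=(3.5, 2.0), dpi=200)
ax.bar([0, 1, 2], [1.0, 1.5, 1.2])  # Placeholder data;
figure_dpi_size = png_size(fig, dpi="figure")  # Reuse the figure's DPI (200);
savefig_dpi_size = png_size(fig, dpi=300)  # The `savefig` DPI overrides the figure's DPI;
plt.close(fig)
print(figure_dpi_size, savefig_dpi_size)
```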

We are going to save our plot as both PDF and PNG.

> 💡 The cell below contains the full code to create the plot from scratch so that the cell can be run independently.

In [None]:
from matplotlib_inline.backend_inline import set_matplotlib_formats
from matplotlib.patches import Rectangle
from matplotlib.container import BarContainer
from matplotlib.text import Text
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from typing import Optional

#############
# Functions #
#############


# Put the data loading logic into a function;
def load_data() -> pd.DataFrame:
    data = pd.read_csv("../data/notebooks/1_getting_started_with_barplots/barplot_data.csv")

    # Obtain the datasets sorted by baseline value, from low to high;
    sorted_categories = data[data["model"] == "A"].sort_values(by="value")["dataset"].unique().tolist()

    # Add the average results;
    average_results: pd.DataFrame = data.groupby("model").mean(numeric_only=True)
    average_results = average_results.reset_index()
    average_results["dataset"] = "Average"
    data = pd.concat([average_results, data])

    # Create categorical variables for the dataset column.
    # We do this transformation at the end, so that we have already added the "Average" category.
    # Let's ensure that the "Average" is still the first category;
    data["dataset"] = pd.Categorical(data["dataset"], categories=["Average"] + sorted_categories)

    return data


def higher_is_better_arrow(ax: Axes, linewidth: float = 1) -> Axes:
    # Add whitespace to the left of the leftmost bar;
    ax.set_xlim(ax.get_xlim()[0] - 0.3, ax.get_xlim()[1])
    # Get the x-axis coordinates in data coordinates,
    # and the y-axis coordinates in axis coordinates;
    first_bar: Rectangle = ax.patches[0]
    x_coord: float = first_bar.get_x() - (first_bar.get_x() - ax.get_xlim()[0]) / 2
    y_start: float = 0.01
    y_end: float = 0.99
    # Draw the arrow;
    ax.annotate(
        "",
        xy=(x_coord, y_end),
        xytext=(x_coord, y_start),
        arrowprops=dict(
            arrowstyle="->",
            color="#2f2f2f",
            linewidth=linewidth,
        ),
        xycoords=ax.get_xaxis_transform(),
    )
    return ax


def add_relative_performance_labels_to_bars(
    ax: Axes, vertical_offset_points: float = 0.5, font_size: Optional[float] = None
) -> Axes:
    # Bars are grouped in `BarContainer`, but each `BarContainer` contains
    # the bars for a given model, not for a single dataset!
    # So we need to re-group the bars so that they are grouped by dataset,
    # obtain the height of the minimum bar (the baseline), and compute the relative performance;
    containers: list[BarContainer] = ax.containers
    # From [container x bars] to [bars x container];
    bars_grouped_by_dataset: list[list[Rectangle]] = list(zip(*containers))
    relative_heights_by_group: list[list[float]] = []
    for group in bars_grouped_by_dataset:
        # Bar with the lowest value for each dataset;
        min_height = min([bar.get_height() for bar in group])
        # Compute the relative performance for each bar in the group;
        relative_heights_by_group.append([bar.get_height() / min_height for bar in group])
    # From [bars x container] to [container x bars], then flatten the list of values;
    relative_heights: list[float] = [value for group in zip(*relative_heights_by_group) for value in group]
    # Associate each bar to its corresponding label;
    bars: list[Rectangle] = [p for p in ax.patches if isinstance(p, Rectangle)]
    for bar, relative_height in zip(bars, relative_heights):
        # If the relative_height is 1 (i.e. the baseline), we can skip it;
        if relative_height == 1:
            continue
        label_x_coordinate = bar.get_x() + bar.get_width() / 2
        label_y_coordinate = bar.get_height()
        label = f"+{relative_height - 1:.0%}"
        ax.annotate(
            text=label,
            xy=(label_x_coordinate, label_y_coordinate),  # Coordinates of the label, in data-coordinates;
            xytext=(0, vertical_offset_points),  # Coordinates of the text, as offset points w.r.t. `xy`;
            textcoords="offset points",
            ha="center",
            va="bottom",
            color="#2f2f2f",
            fontsize=font_size,
        )
    return ax


def barplot(data: pd.DataFrame, palette: list[str], font_size: int, ax: Axes) -> Axes:
    # Main barplot;
    ax: Axes = sns.barplot(data=data, x="dataset", y="value", hue="model", palette=palette, ax=ax, saturation=0.9)

    # Update the legend labels;
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(
        handles=handles,
        labels=["Baseline", "State-of-the-art", "Our model"],
        fontsize=font_size,
        loc="upper left",
        bbox_to_anchor=(0.05, 1.0),
    )

    # Update the x-axis and y-axis labels and settings;
    ax.set_xlabel("Dataset", fontsize=font_size)
    ax.set_ylabel("Value", fontsize=font_size, labelpad=1)

    # Add the arrow to denote that higher values are better;
    ax = higher_is_better_arrow(ax)

    # Add a grid to the y-axis. Make sure that the grid is drawn below the bars, instead of above;
    ax.set_axisbelow(True)
    ax.grid(axis="y", linestyle="--", linewidth=0.5)

    # Add labels to bars;
    ax = add_relative_performance_labels_to_bars(ax, font_size=font_size * 0.4)

    # Update the x-axis tick labels;
    x_tick_labels: list[Text] = ax.get_xticklabels()
    ax.set_xticklabels([x.get_text().replace("_", " ").title() for x in x_tick_labels])

    # Set the font size of the x-axis and y-axis tick labels;
    ax.tick_params(axis="both", labelsize=font_size * 0.7)

    return ax


############
# Plotting #
############

# Increase quality of plots in the notebook,
# and prevent automatic adjustments to the margins;
set_matplotlib_formats("retina", bbox_inches=None)
plt.rcdefaults()

# Load the data
data: pd.DataFrame = load_data()

# Set the width to 3.5 inches, and a 16:9 aspect ratio;
figure_size: tuple[float, float] = (3.5, 3.5 * 9 / 16)

# Color palette for the bars;
palette = ["#E8CFB5", "#81C798", "#358787"]

# Create a figure with the specified size;
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figure_size, dpi=200)

# Adjust the margins of the figure;
fig.subplots_adjust(left=0.11, right=0.99, bottom=0.17, top=0.98)

# Main plot;
ax = barplot(data=data, palette=palette, font_size=8, ax=ax)

# Save the plot as PDF and PNG. This must be done before calling `plt.show()`:
# in the notebook, `plt.show()` closes the figure, and a later `plt.savefig` would save a new, empty figure.
# The `dpi=300` passed to `savefig` overrides the `dpi=200` set in `plt.subplots`, for the PNG only;
plt.savefig("../plots/notebooks/1_getting_started_with_barplots/our_amazing_barplot.pdf")
plt.savefig("../plots/notebooks/1_getting_started_with_barplots/our_amazing_barplot.png", dpi=300)

# Show the barplot;
plt.show()

## 🏁 Final touches

Obsession with tiny details is what makes the difference between good and great! Our current plot is in great shape, but there are still a few minor changes we can make. Most of these come with a good amount of subjectivity and personal preference. If you don't like the look of a specific change, feel free to ignore it! But knowing that these things are possible is a good starting point.

First, we will do the following. Changes in the code below are marked with the 🌞 emoji.
1. Add a contour to the bars, to make them stand out more.
2. Separate the "Average" bar group from the other groups with a vertical line, to make it stand out more.
3. Modify the bar label format, from `+49%` to `1.49X`. While the first format is arguably clearer, it was causing some overlap between labels and bars. The second format is more compact. Reduce their size a bit, as well.
4. Make the arrow thinner, since now it is a bit too thick.
5. Reduce the whitespace added to the left of the leftmost bar, which is required to fit the "higher is better" arrow.
6. Control the thickness of the axes' borders, and of the ticks. They must have the same thickness, to look good.
7. Manually set the y-axis limits.
8. Control the position and number of y-axis ticks, so that the highest tick denotes the y-axis upper limit.
9. Control the formatting of the y-axis ticks, so they are displayed with two decimal places.
10. Reduce the padding between ticks and tick labels, so that they look less "floaty".
11. As a consequence of point `10`, we can further reduce the bottom margin of the plot.
12. Set Arial as the font for the plot.
13. Reduce the distance between the x-axis label and the x-axis.
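Points 7, 8, and 9 rest on a small piece of arithmetic. `LinearLocator(numticks)` spreads `numticks` evenly spaced ticks between the current axis limits. With limits `(0, 1.75)` and 8 ticks, there are 7 intervals of 1.75 / 7 = 0.25, and the last tick lands exactly on the upper limit. The sketch below assumes the headless Agg backend and uses an empty Axes. It applies the same limits, locator, and two-decimal formatter used in this section, then reads back the resulting tick labels.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend, no window needed;
import matplotlib.pyplot as plt
from matplotlib.ticker import LinearLocator

fig, ax = plt.subplots(figsize=(3.5, 2))

# Fix the y-axis limits: LinearLocator spreads its ticks between them;
ax.set_ylim(0, 1.75)
# 8 ticks over [0, 1.75] means 7 intervals of 0.25 each;
ax.yaxis.set_major_locator(LinearLocator(8))
# Format every tick label with two decimal places;
ax.yaxis.set_major_formatter(lambda x, pos: f"{x:.2f}")

# Ask the locator for the tick positions, and format them as the axis would;
ticks = ax.yaxis.get_major_locator()()
formatter = ax.yaxis.get_major_formatter()
labels = [formatter(t, i) for i, t in enumerate(ticks)]
plt.close(fig)

print(labels)
# ['0.00', '0.25', '0.50', '0.75', '1.00', '1.25', '1.50', '1.75']
```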

> 💡 The cell below contains the full code to create the plot from scratch, so that it can be run independently.
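Each cell resets the global style with `plt.rcdefaults()`, so that changes to `plt.rcParams` made in one cell do not leak into the next. One possible alternative, sketched below under the assumption of the headless Agg backend, is `plt.rc_context`. It applies a set of overrides only inside a `with` block and restores the previous values on exit. Artists created inside the block, like the spines here, pick up the overridden values. Some elements are only created at draw time, so to be safe, keep plotting and saving inside the same block.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend, no window needed;
import matplotlib.pyplot as plt

plt.rcdefaults()
default_width = plt.rcParams["axes.linewidth"]

# Temporarily override the same parameters that the cell below sets globally;
overrides = {"axes.linewidth": 0.7, "xtick.major.width": 0.7, "ytick.major.width": 0.7}
with plt.rc_context(overrides):
    fig, ax = plt.subplots(figsize=(3.5, 2))
    # Spines read `axes.linewidth` when they are created, inside the block;
    spine_width = ax.spines["left"].get_linewidth()
    inside_width = plt.rcParams["axes.linewidth"]
    plt.close(fig)

# On exit, the previous value is restored;
after_width = plt.rcParams["axes.linewidth"]
print(default_width, inside_width, spine_width, after_width)
```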

In [None]:
from matplotlib_inline.backend_inline import set_matplotlib_formats
from matplotlib.patches import Rectangle
from matplotlib.container import BarContainer
from matplotlib.text import Text
from matplotlib.axes import Axes
from matplotlib.ticker import LinearLocator
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from typing import Optional

#############
# Functions #
#############


# Put the data loading logic into a function;
def load_data() -> pd.DataFrame:
    data = pd.read_csv("../data/notebooks/1_getting_started_with_barplots/barplot_data.csv")

    # Obtain the datasets sorted by baseline value, from low to high;
    sorted_categories = data[data["model"] == "A"].sort_values(by="value")["dataset"].unique().tolist()

    # Add the average results;
    average_results: pd.DataFrame = data.groupby("model").mean(numeric_only=True)
    average_results = average_results.reset_index()
    average_results["dataset"] = "Average"
    data = pd.concat([average_results, data])

    # Create categorical variables for the dataset column.
    # We do this transformation at the end, so that we have already added the "Average" category.
    # Let's ensure that the "Average" is still the first category;
    data["dataset"] = pd.Categorical(data["dataset"], categories=["Average"] + sorted_categories)

    return data


def higher_is_better_arrow(ax: Axes, linewidth: float = 1) -> Axes:
    # Add whitespace to the left of the leftmost bar;
    ax.set_xlim(ax.get_xlim()[0] - 0.2, ax.get_xlim()[1])  # 🌞 Reduce the space added to the left from 0.3;
    # Get the x-axis coordinates in data coordinates,
    # and the y-axis coordinates in axis coordinates;
    first_bar: Rectangle = ax.patches[0]
    x_coord: float = first_bar.get_x() - (first_bar.get_x() - ax.get_xlim()[0]) / 2
    y_start: float = 0.01
    y_end: float = 0.99
    # Draw the arrow;
    ax.annotate(
        "",
        xy=(x_coord, y_end),
        xytext=(x_coord, y_start),
        arrowprops=dict(
            arrowstyle="->",
            color="#2f2f2f",
            linewidth=linewidth,
        ),
        xycoords=ax.get_xaxis_transform(),
    )
    return ax


def add_relative_performance_labels_to_bars(
    ax: Axes, vertical_offset_points: float = 0.5, font_size: Optional[float] = None
) -> Axes:
    # Bars are grouped in `BarContainer`, but each `BarContainer` contains
    # the bars for a given model, not for a single dataset!
    # So we need to re-group the bars so that they are grouped by dataset,
    # obtain the height of the minimum bar (the baseline), and compute the relative performance;
    containers: list[BarContainer] = ax.containers
    # From [container x bars] to [bars x container];
    bars_grouped_by_dataset: list[list[Rectangle]] = list(zip(*containers))
    relative_heights_by_group: list[list[float]] = []
    for group in bars_grouped_by_dataset:
        # Bar with the lowest value for each dataset;
        min_height = min([bar.get_height() for bar in group])
        # Compute the relative performance for each bar in the group;
        relative_heights_by_group.append([bar.get_height() / min_height for bar in group])
    # From [bars x container] to [container x bars], then flatten the list of values;
    relative_heights: list[float] = [value for group in zip(*relative_heights_by_group) for value in group]
    # Associate each bar to its corresponding label;
    bars: list[Rectangle] = [p for p in ax.patches if isinstance(p, Rectangle)]
    for bar, relative_height in zip(bars, relative_heights):
        # If the relative_height is 1 (i.e. the baseline), we can skip it;
        if relative_height == 1:
            continue
        label_x_coordinate = bar.get_x() + bar.get_width() / 2
        label_y_coordinate = bar.get_height()
        label = f"{relative_height:.2f}X"  # 🌞 Change the bar label format;
        ax.annotate(
            text=label,
            xy=(label_x_coordinate, label_y_coordinate),  # Coordinates of the label, in data-coordinates;
            xytext=(0, vertical_offset_points),  # Coordinates of the text, as offset points w.r.t. `xy`;
            textcoords="offset points",
            ha="center",
            va="bottom",
            color="#2f2f2f",
            fontsize=font_size,
        )
    return ax


def barplot(data: pd.DataFrame, palette: list[str], font_size: int, ax: Axes) -> Axes:
    # Main barplot;
    ax: Axes = sns.barplot(
        data=data,
        x="dataset",
        y="value",
        hue="model",
        palette=palette,
        ax=ax,
        saturation=0.9,
        linewidth=0.5,  # 🌞 Set the width of the bar edges;
        edgecolor="#2f2f2f",  # 🌞 Set the color of the bar edges;
    )

    # Update the legend labels;
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(
        handles=handles,
        labels=["Baseline", "State-of-the-art", "Our model"],
        fontsize=font_size,
        loc="upper left",
        bbox_to_anchor=(0.05, 1.0),
    )

    # Update the x-axis and y-axis labels and settings;
    ax.set_xlabel(
        "Dataset", fontsize=font_size, labelpad=2  # 🌞 Decrease the distance between the x-axis label and the x-axis;
    )
    ax.set_ylabel("Value", fontsize=font_size, labelpad=1)

    # Add the arrow to denote that higher values are better;
    ax = higher_is_better_arrow(
        ax,
        linewidth=0.6,  # 🌞 Make the arrow thinner;
    )

    # Add a grid to the y-axis. Make sure that the grid is drawn below the bars, instead of above;
    ax.set_axisbelow(True)
    ax.grid(axis="y", linestyle="--", linewidth=0.5)

    # Add labels to bars;
    ax = add_relative_performance_labels_to_bars(
        ax, font_size=font_size * 0.35  # 🌞 Size of bar labels reduced from 0.4;
    )

    # Update the x-axis tick labels;
    x_tick_labels: list[Text] = ax.get_xticklabels()
    ax.set_xticklabels([x.get_text().replace("_", " ").title() for x in x_tick_labels])

    # 🌞 Add a vertical line to separate the "Average" category from the other categories;
    ax.axvline(x=0.5, color="#2f2f2f", linewidth=0.5, linestyle="--")

    # 🌞 Set the y-axis limits;
    ax.set_ylim(0, 1.75)
    # 🌞 Set the y-axis ticks so that there is a tick at each y-axis limit;
    ax.yaxis.set_major_locator(LinearLocator(8))
    # 🌞 Set the y-axis tick labels to have 2 decimal places;
    ax.yaxis.set_major_formatter(lambda x, pos: f"{x:.2f}")

    # Set the font size of the x-axis and y-axis tick labels;
    ax.tick_params(
        axis="both",
        labelsize=font_size * 0.7,
        pad=1,  # 🌞 Reduce the padding of tick labels;
    )

    return ax


############
# Plotting #
############

# Increase quality of plots in the notebook,
# and prevent automatic adjustments to the margins;
set_matplotlib_formats("retina", bbox_inches=None)
plt.rcdefaults()

# 🌞 Make the axes border and the major ticks slightly thinner than Matplotlib's
# default width of 0.8, so that they all have the same width.
# We do so by changing the global parameters;
plt.rcParams["axes.linewidth"] = 0.7
plt.rcParams["xtick.major.width"] = 0.7
plt.rcParams["ytick.major.width"] = 0.7

# 🌞 Set Arial as font
plt.rcParams["font.family"] = "Arial"

# Load the data
data: pd.DataFrame = load_data()

# Set the width to 3.5 inches, and a 16:9 aspect ratio;
figure_size: tuple[float, float] = (3.5, 3.5 * 9 / 16)

# Color palette for the bars;
palette = ["#E8CFB5", "#81C798", "#358787"]

# Create a figure with the specified size;
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figure_size, dpi=200)

# Adjust the margins of the figure;
fig.subplots_adjust(
    left=0.11,
    right=0.99,
    bottom=0.16,  # 🌞 Reduce the bottom margin from 0.17;
    top=0.97,  # 🌞 Lower the top edge of the Axes from 0.98, enlarging the top margin to make space for the highest y-axis tick label;
)

# Main plot;
ax = barplot(data=data, palette=palette, font_size=8, ax=ax)

# Save the plot as PDF and PNG;
plt.savefig("../plots/notebooks/1_getting_started_with_barplots/our_amazing_barplot_v2.pdf")
plt.savefig("../plots/notebooks/1_getting_started_with_barplots/our_amazing_barplot_v2.png", dpi=300)

# Show the barplot;
plt.show()
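A note on the `subplots_adjust` arguments: `left`, `right`, `bottom`, and `top` are the positions of the Axes edges as fractions of the figure width and height, not margin sizes. So `bottom=0.16` leaves 16% of the height below the Axes, while `top=0.97` leaves 1 - 0.97 = 3% above it. Lowering `top` therefore enlarges the top margin. The sketch below assumes the headless Agg backend. It reads back the Axes position and converts it into margins in inches for the figure size used in this notebook.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend, no window needed;
import matplotlib.pyplot as plt

figure_size = (3.5, 3.5 * 9 / 16)
fig, ax = plt.subplots(figsize=figure_size)
fig.subplots_adjust(left=0.11, right=0.99, bottom=0.16, top=0.97)

# Position of the Axes in figure coordinates (0 = bottom/left edge, 1 = top/right edge);
position = ax.get_position()
width_in, height_in = fig.get_size_inches()

# Margins are the distances between the Axes edges and the figure edges, in inches;
margins_in = {
    "left": position.x0 * width_in,
    "right": (1 - position.x1) * width_in,
    "bottom": position.y0 * height_in,
    "top": (1 - position.y1) * height_in,
}
plt.close(fig)

for side, size in margins_in.items():
    print(f"{side}: {size:.3f} in")
```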

### Customizing the legend

Last but not least, let's customize the legend of the plot. The current legend works great, and there's no reason to modify it.
But for the sake of it, let's see how we can adjust its style and position.

We will do the following.
1. Make the legend horizontal.
2. Place it outside the axes, below the plot.
3. Customize the border of the legend.
4. Reduce the width of the legend handles (the colored rectangles that correspond to each category), and the spacing between handles, labels, and columns.
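The four changes above come down to a handful of `Figure.legend` keyword arguments. The sketch below isolates them on a toy bar chart drawn with plain Matplotlib, assuming the headless Agg backend. The bar heights are made up for illustration. Because `ax.bar` does not create a legend by itself, there is no Axes legend to remove first. The Seaborn plot is different, since it adds an Axes legend when `hue` is set. The anchor `x=0.55` is halfway between `left=0.11` and `right=0.99`, so the legend is centered on the Axes rather than on the figure.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend, no window needed;
import matplotlib.pyplot as plt

labels = ["Baseline", "State-of-the-art", "Our model"]
palette = ["#E8CFB5", "#81C798", "#358787"]

fig, ax = plt.subplots(figsize=(3.5, 3.5 * 9 / 16))
# Leave room below the Axes for the legend;
fig.subplots_adjust(left=0.11, right=0.99, bottom=0.3, top=0.97)

# Toy data: one bar per model, with made-up heights; each call returns a BarContainer;
handles = [
    ax.bar(i, height, color=color, edgecolor="#2f2f2f", linewidth=0.5)
    for i, (height, color) in enumerate(zip([1.0, 1.2, 1.5], palette))
]

legend = fig.legend(
    handles=handles,
    labels=labels,
    loc="lower center",  # Align the bottom-center of the legend...
    bbox_to_anchor=(0.55, 0),  # ...with this point, in figure coordinates;
    ncol=len(labels),  # One column per entry: a horizontal legend;
    handlelength=1.3,  # Narrower colored rectangles;
    handletextpad=0.4,  # Less space between rectangles and labels;
    columnspacing=1,  # Less space between columns;
    edgecolor="#2f2f2f",  # Color of the legend border;
    framealpha=1,  # Fully opaque frame;
    fancybox=False,  # Square corners;
)
legend.get_frame().set_linewidth(0.5)  # Thin legend border;

texts = [t.get_text() for t in legend.get_texts()]
plt.close(fig)
print(texts)
```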

And that's it! This is how you can create an amazing barplot, and customize even its smallest details. 👏

> 💡 Plots in `segretini-matplottini` have a custom legend with a shadow below it, with opacity `1` and no blurring. By default, Matplotlib uses a shadow with a slight blur and transparency. Achieving the same look as `segretini-matplottini` requires some low-level changes to how the legend is rendered, and is beyond the scope of this notebook. However, adding such a legend is simple! Just use `segretini_matplottini.utils.add_legend_with_dark_shadow`, passing the legend handles and labels to it.

In [None]:
from matplotlib_inline.backend_inline import set_matplotlib_formats
from matplotlib.patches import Rectangle
from matplotlib.container import BarContainer
from matplotlib.text import Text
from matplotlib.axes import Axes
from matplotlib.ticker import LinearLocator
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from typing import Optional

#############
# Functions #
#############


# Put the data loading logic into a function;
def load_data() -> pd.DataFrame:
    data = pd.read_csv("../data/notebooks/1_getting_started_with_barplots/barplot_data.csv")

    # Obtain the datasets sorted by baseline value, from low to high;
    sorted_categories = data[data["model"] == "A"].sort_values(by="value")["dataset"].unique().tolist()

    # Add the average results;
    average_results: pd.DataFrame = data.groupby("model").mean(numeric_only=True)
    average_results = average_results.reset_index()
    average_results["dataset"] = "Average"
    data = pd.concat([average_results, data])

    # Create categorical variables for the dataset column.
    # We do this transformation at the end, so that we have already added the "Average" category.
    # Let's ensure that the "Average" is still the first category;
    data["dataset"] = pd.Categorical(data["dataset"], categories=["Average"] + sorted_categories)

    return data


def higher_is_better_arrow(ax: Axes, linewidth: float = 1) -> Axes:
    # Add whitespace to the left of the leftmost bar;
    ax.set_xlim(ax.get_xlim()[0] - 0.2, ax.get_xlim()[1])
    # Get the x-axis coordinates in data coordinates,
    # and the y-axis coordinates in axis coordinates;
    first_bar: Rectangle = ax.patches[0]
    x_coord: float = first_bar.get_x() - (first_bar.get_x() - ax.get_xlim()[0]) / 2
    y_start: float = 0.01
    y_end: float = 0.99
    # Draw the arrow;
    ax.annotate(
        "",
        xy=(x_coord, y_end),
        xytext=(x_coord, y_start),
        arrowprops=dict(
            arrowstyle="->",
            color="#2f2f2f",
            linewidth=linewidth,
        ),
        xycoords=ax.get_xaxis_transform(),
    )
    return ax


def add_relative_performance_labels_to_bars(
    ax: Axes, vertical_offset_points: float = 0.5, font_size: Optional[float] = None
) -> Axes:
    # Bars are grouped in `BarContainer`, but each `BarContainer` contains
    # the bars for a given model, not for a single dataset!
    # So we need to re-group the bars so that they are grouped by dataset,
    # obtain the height of the minimum bar (the baseline), and compute the relative performance;
    containers: list[BarContainer] = ax.containers
    # From [container x bars] to [bars x container];
    bars_grouped_by_dataset: list[list[Rectangle]] = list(zip(*containers))
    relative_heights_by_group: list[list[float]] = []
    for group in bars_grouped_by_dataset:
        # Bar with the lowest value for each dataset;
        min_height = min([bar.get_height() for bar in group])
        # Compute the relative performance for each bar in the group;
        relative_heights_by_group.append([bar.get_height() / min_height for bar in group])
    # From [bars x container] to [container x bars], then flatten the list of values;
    relative_heights: list[float] = [value for group in zip(*relative_heights_by_group) for value in group]
    # Associate each bar to its corresponding label;
    bars: list[Rectangle] = [p for p in ax.patches if isinstance(p, Rectangle)]
    for bar, relative_height in zip(bars, relative_heights):
        # If the relative_height is 1 (i.e. the baseline), we can skip it;
        if relative_height == 1:
            continue
        label_x_coordinate = bar.get_x() + bar.get_width() / 2
        label_y_coordinate = bar.get_height()
        label = f"{relative_height:.2f}X"
        ax.annotate(
            text=label,
            xy=(label_x_coordinate, label_y_coordinate),  # Coordinates of the label, in data-coordinates;
            xytext=(0, vertical_offset_points),  # Coordinates of the text, as offset points w.r.t. `xy`;
            textcoords="offset points",
            ha="center",
            va="bottom",
            color="#2f2f2f",
            fontsize=font_size,
        )
    return ax


def barplot(data: pd.DataFrame, palette: list[str], font_size: int, ax: Axes) -> Axes:
    # Main barplot;
    ax: Axes = sns.barplot(
        data=data,
        x="dataset",
        y="value",
        hue="model",
        palette=palette,
        ax=ax,
        saturation=0.9,
        linewidth=0.5,
        edgecolor="#2f2f2f",
    )

    # Update the x-axis and y-axis labels and settings;
    ax.set_xlabel("Dataset", fontsize=font_size, labelpad=2)
    ax.set_ylabel("Value", fontsize=font_size, labelpad=1)

    # Add the arrow to denote that higher values are better;
    ax = higher_is_better_arrow(
        ax,
        linewidth=0.6,
    )

    # Add a grid to the y-axis. Make sure that the grid is drawn below the bars, instead of above;
    ax.set_axisbelow(True)
    ax.grid(axis="y", linestyle="--", linewidth=0.5)

    # Add labels to bars;
    ax = add_relative_performance_labels_to_bars(ax, font_size=font_size * 0.35)

    # Update the x-axis tick labels;
    x_tick_labels: list[Text] = ax.get_xticklabels()
    ax.set_xticklabels([x.get_text().replace("_", " ").title() for x in x_tick_labels])

    # Add a vertical line to separate the "Average" category from the other categories;
    ax.axvline(x=0.5, color="#2f2f2f", linewidth=0.5, linestyle="--")

    # Set the y-axis limits;
    ax.set_ylim(0, 1.75)
    # Set the y-axis ticks so that there is a tick at each y-axis limit;
    ax.yaxis.set_major_locator(LinearLocator(8))
    # Set the y-axis tick labels to have 2 decimal places;
    ax.yaxis.set_major_formatter(lambda x, pos: f"{x:.2f}")

    # Set the font size of the x-axis and y-axis tick labels;
    ax.tick_params(
        axis="both",
        labelsize=font_size * 0.7,
        pad=1,
    )

    # Update the legend labels;
    handles, _ = ax.get_legend_handles_labels()
    # 🌞 To place the legend below the plot, we first delete the existing legend,
    # and then draw a new one, attached to the Figure instead of the Axes.
    # The Figure is taken from the Axes, instead of relying on a global `fig` variable;
    ax.get_legend().remove()
    legend = ax.get_figure().legend(
        handles=handles,
        labels=["Baseline", "State-of-the-art", "Our model"],
        fontsize=font_size,
        loc="lower center",  # 🌞 Place the legend below the plot, at the center;
        bbox_to_anchor=(0.55, 0),  # 🌞 Center on the Axes: 0.55 is halfway between left=0.11 and right=0.99;
        ncol=len(data["model"].unique()),  # 🌞 Three columns, horizontal legend;
        handlelength=1.3,  # 🌞 Reduce the width of the legend handles, the "rectangles";
        handletextpad=0.4,  # 🌞 Reduce the space between the legend handles and the legend labels;
        columnspacing=1,  # 🌞 Reduce the space between the legend columns;
        edgecolor="#2f2f2f",  # 🌞 Color of the legend border;
        framealpha=1,  # 🌞 Opacity of the legend frame (background and border), make it fully opaque;
        fancybox=False,  # 🌞 Do not draw a rounded box, but a box with straight corners;
    )
    legend.get_frame().set_linewidth(0.5)  # 🌞 Width of the legend border;
    return ax


############
# Plotting #
############

# Increase quality of plots in the notebook,
# and prevent automatic adjustments to the margins;
set_matplotlib_formats("retina", bbox_inches=None)
plt.rcdefaults()

# Make the axes border and the major ticks slightly thinner than Matplotlib's
# default width of 0.8, so that they all have the same width.
# We do so by changing the global parameters;
plt.rcParams["axes.linewidth"] = 0.7
plt.rcParams["xtick.major.width"] = 0.7
plt.rcParams["ytick.major.width"] = 0.7

# Set Arial as font
plt.rcParams["font.family"] = "Arial"

# Load the data
data: pd.DataFrame = load_data()

# Set the width to 3.5 inches, and a 16:9 aspect ratio;
figure_size: tuple[float, float] = (3.5, 3.5 * 9 / 16)

# Color palette for the bars;
palette = ["#E8CFB5", "#81C798", "#358787"]

# Create a figure with the specified size;
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figure_size, dpi=200)

# Adjust the margins of the figure;
fig.subplots_adjust(
    left=0.11,
    right=0.99,
    bottom=0.3,  # 🌞 Increase the bottom margin to make space for the legend;
    top=0.97,
)

# Main plot;
ax = barplot(data=data, palette=palette, font_size=8, ax=ax)

# Save the plot as PDF and PNG;
plt.savefig("../plots/notebooks/1_getting_started_with_barplots/our_amazing_barplot_v3.pdf")
plt.savefig("../plots/notebooks/1_getting_started_with_barplots/our_amazing_barplot_v3.png", dpi=300)

# Show the barplot;
plt.show()
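The least obvious part of `add_relative_performance_labels_to_bars` is the double transposition with `zip(*...)`. Bars arrive grouped by model, with one `BarContainer` per hue level. The ratios must be computed per dataset. The results must then be flattened back in model order, so that they line up with `ax.patches`. The sketch below replays that logic on plain lists, using made-up numbers in place of bar heights, so each step can be inspected without drawing anything.

```python
# Made-up bar heights, grouped like `ax.containers`: one list per model, one value per dataset;
heights_by_model: list[list[float]] = [
    [1.00, 2.00, 4.00],  # Model A (the baseline, always the lowest here);
    [1.25, 2.50, 4.40],  # Model B;
    [1.50, 3.00, 6.00],  # Model C;
]

# From [model x dataset] to [dataset x model];
heights_by_dataset = list(zip(*heights_by_model))

# Divide each height by the lowest height within its dataset;
relative_by_dataset = [[h / min(group) for h in group] for group in heights_by_dataset]

# Back from [dataset x model] to [model x dataset], then flatten,
# matching the order in which the bars appear in `ax.patches`;
relative_flat = [value for group in zip(*relative_by_dataset) for value in group]

# Same label format as the plot, skipping the baseline bars (ratio exactly 1);
labels = [f"{r:.2f}X" for r in relative_flat if r != 1]
print(labels)
```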