# Introduction to Data Vis in Python: Day 2

Welcome to day 2! This session is a lot more hands-on: we'll take some of the concepts we learned in day 1 and start applying them to lots of different plot types.

## Recap

Let's start off with a very brief recap of basic Matplotlib, in the form of a quick exercise!

In [None]:
# first, we need to import our libraries

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# generate some quick random data

x = np.random.rand(50)
y = np.random.rand(50)
z = np.linspace(0, 1, 50)

In [None]:
# Reminder: this is your scatter plot with a colourbar

fig, ax = plt.subplots()
sp = ax.scatter(x, y, c=z, s=z*200, alpha=0.8)
ax.set_title('Scatter Plot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

fig.colorbar(sp, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)

## Exercise 1

Modify the code below to remove the colourbar and build a line plot instead of a scatter plot.

Useful links:

- [ColorBrewer](https://colorbrewer2.org/#type=sequential&scheme=BuGn&n=3)
- [Matplotlib lineplots](https://matplotlib.org/stable/gallery/lines_bars_and_markers/simple_plot.html)
- [Matplotlib linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)
- [Setting Matplotlib colours with hexcodes etc.](https://matplotlib.org/stable/gallery/color/color_demo.html)

In [None]:
# Exercise: modify or delete lines below to create a
# line plot of x vs. z,
# in a colour from ColorBrewer

fig, ax = plt.subplots()
sp = ax.scatter(x, y, c=z, s=z*200, alpha=0.8)
ax.set_title('Scatter Plot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

fig.colorbar(sp, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)
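
If you get stuck, here is one possible answer, written as a self-contained sketch. It recreates random `x` and `z` values like the recap cell (so your line will look different), puts the sorted `z` on the horizontal axis so the line runs left to right, and uses the hex code `#2ca25f`, which is assumed here to be the darkest shade of the 3-class BuGn scheme from the ColorBrewer link; any colour you pick there works the same way.

```python
import numpy as np
import matplotlib.pyplot as plt

# Recreate data like the recap cell (random, so your line will differ)
x = np.random.rand(50)
z = np.linspace(0, 1, 50)

fig, ax = plt.subplots()
# No scatter, no colourbar: one line, coloured with a ColorBrewer hex code
line, = ax.plot(z, x, color="#2ca25f", linestyle="--", linewidth=2)
ax.set_title("Line Plot")
ax.set_xlabel("z")
ax.set_ylabel("x")
```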

## Plotting multiple series

Now on to new content!

When plotting multiple series, our plots can quickly become a bit busy.

How do we move the legend around?

- [Legend guide](https://matplotlib.org/stable/users/explain/axes/legend_guide.html)
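
As a first taste of what the legend guide covers: when you pass `bbox_to_anchor` without `bbox_transform`, the coordinates are read relative to the axes, where (0, 0) is the bottom-left corner and (1, 1) the top-right. The sketch below (with freshly generated random data and a variable name `leg` chosen for illustration) anchors the legend's upper-left corner just outside the right-hand edge of the axes.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.random.rand(50)
y = np.random.rand(50)

fig, ax = plt.subplots()
ax.scatter(x, y, label="Data")
ax.scatter(x, -y - 0.1, c="grey", label="Other data")

# Without bbox_transform, bbox_to_anchor is in axes coordinates:
# (1.02, 1) is slightly right of the axes' top-right corner.
# loc="upper left" says which corner of the legend sits on that point.
leg = ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1))
```

Because the legend now sits outside the axes, it may be cropped when saving unless you use `bbox_inches="tight"` in `fig.savefig()`.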

In [None]:
fig, ax = plt.subplots()
sp = ax.scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")
ax.set_title('Scatter Plot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

ax.scatter(x, -y-0.1, c="grey", s=200, alpha=0.6, edgecolor="black", label="Other data")

fig.colorbar(sp, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)

ax.legend()

In [None]:
# How do we move the legend out of the way?

fig, ax = plt.subplots()
sp = ax.scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")
ax.set_title('Scatter Plot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

ax.scatter(x, -y-0.1, c="grey", s=200, alpha=0.6, edgecolor="black", label="Other data")

fig.colorbar(sp, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)

ax.legend(loc='center', bbox_to_anchor=(0.5, -0.02),
          bbox_transform=fig.transFigure, frameon=False,
          ncols=2)

## Exercise 2: Explore the legend!

Modify the cell below to move the legend around.

Try:

- Removing all other arguments except `loc`
- Include `bbox_to_anchor` but not `bbox_transform`
- Figure out what `ncols` and `frameon` do; what are their defaults?

To view the documentation for a function, you can use one of the following:

- Type the function and wait for the auto docstring to show: `ax.legend()`
- Type `ax.legend()`, then hover your mouse over the function; you can then select to open this in a panel to the side
- Or, type `ax.legend?` and hit Ctrl + Enter to run, and open the docstring in a side panel

In [None]:
ax.legend()

In [None]:
fig, ax = plt.subplots()
sp = ax.scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")
ax.set_title('Scatter Plot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

ax.scatter(x, -y-0.1, c="grey", s=200, alpha=0.6, edgecolor="black", label="Other data")

fig.colorbar(sp, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)

ax.legend(loc='center', bbox_to_anchor=(0.5, -0.02),
          bbox_transform=fig.transFigure, frameon=False,
          ncols=2)

## Multi-Panel Plots

Multi-panel (gridded or faceted) plots can make complex data easier to follow, while keeping the panels easier to compare than totally separate figures.

* Easy comparison of trends, distributions, or relationships across different groups, conditions, or time periods.
* Reduces clutter and cognitive load: no need to pack *multiple encodings* into every point
* All plots share the same overall figure space, often with aligned axes, ensuring a consistent visual context for interpretation.
* Helps communicate by placing the plots in a logical sequence or arrangement.

### General Approach with `plt.subplots()`

The basic syntax is `fig, ax = plt.subplots(nrows, ncols)`, where:
*   `nrows` specifies the number of rows of subplots.
*   `ncols` specifies the number of columns of subplots.

If `nrows` or `ncols` is greater than 1, `ax` will be a NumPy array of `Axes` objects, which you can then iterate over or index directly to plot on specific subplots. This approach automatically creates a grid of equally sized subplots.
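
The exact shape of that array depends on the grid, which often trips people up. A quick sketch (the variable names are just for illustration): a grid with both `nrows` and `ncols` above 1 gives a 2D array, a single row or column is "squeezed" to 1D, `squeeze=False` keeps it 2D, and a single subplot comes back as one `Axes` object. `.flat` walks through a 2D array row by row, which is handy for styling every panel the same way.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# Both nrows and ncols > 1: a 2D array, indexed with [row, column]
fig, axs = plt.subplots(2, 3)
print(axs.shape)  # (2, 3)

# A single row is squeezed to a 1D array: index it with one number
fig_row, axs_row = plt.subplots(1, 3)
print(axs_row.shape)  # (3,)

# squeeze=False always returns a 2D array, even for a single row
fig_keep, axs_keep = plt.subplots(1, 3, squeeze=False)
print(axs_keep.shape)  # (1, 3)

# One subplot: a single Axes object, not an array
fig_one, ax_one = plt.subplots()
print(isinstance(ax_one, Axes))  # True

# .flat visits every panel of the 2D array, row by row
for i, a in enumerate(axs.flat):
    a.set_title(f"Panel {i}")
```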

**Example Use Case**: Comparing a variable's distribution across several categories, or showing a time series for different regions.

Useful link: [quick DIY tool to generate grids](https://quick-subplots.streamlit.app/)

In this example, we're going to step through a few cells to create the example dataset, and then plot it!

In [None]:
# 1. Generate sample data for 4 different plots
# Data for Line Plot
x_line = np.linspace(0, 10, 100)
y_line = np.sin(x_line) + np.random.rand(100) * 0.5

# Data for Scatter Plot
x_scatter = np.random.rand(50) * 10
y_scatter = np.random.rand(50) * 10 + x_scatter * 0.5

# Data for Bar Chart
categories = ['A', 'B', 'C', 'D']
values = np.random.randint(10, 100, len(categories))

# Data for Histogram
hist_data = np.random.randn(1000)

In [None]:
# mini exercise: try viewing the documentation for one of the above functions,
# for example:

np.random.randn()

In [None]:
# plt.subplots(2, 2) will create a 2D array of "axes" objects,
# and we will use array indexing to select them.

# Let's quickly refresh on this!

list_1 = [0, 1, 2]
list_2 = [3, 4, 5]

list_of_lists = [list_1,
                 list_2]

example_array = np.array(list_of_lists)

In [None]:
example_array

In [None]:
example_array.shape # (nrows, ncols)

If we imagine every number is a "subplot" or axes on our figure, how do we access them?

```text
array([[0, 1, 2],
       [3, 4, 5]])
```

We use [array indexing!](https://numpy.org/doc/stable/user/basics.indexing.html)

```text
  cols: 0  1  2    rows:
array([[0, 1, 2],   0
       [3, 4, 5]])  1
```

Access a specific value (or, later, a specific axes) using its coordinates: `object_name[row, column]`

In [None]:
example_array[0,2]
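
Beyond single values, you can also slice out a whole row or column. The same slicing works on an array of subplots, so for example `ax[0, :]` would be the whole top row of axes. A short sketch using the same `example_array`:

```python
import numpy as np

example_array = np.array([[0, 1, 2],
                          [3, 4, 5]])

print(example_array[0, 2])  # 2: row 0, column 2
print(example_array[1, 0])  # 3: row 1, column 0
print(example_array[1])     # [3 4 5]: all of row 1
print(example_array[:, 1])  # [1 4]: column 1 from every row
```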

Back to our gridded plots:

- We create the plot using subplots like we have done before
- Instead of a single `fig, ax` pair, the function will return a single `fig` and an array of `ax` objects
- We will access the different `ax` objects using our array indexing: `ax[0,0]` etc.!

In [None]:
# 2. Create a figure and a 2x2 grid of subplots
fig, ax = plt.subplots(2, 2)

# 3. Plot a different type of graph on each subplot
# Top-left: Line Plot
ax[0, 0].plot(x_line, y_line, color='purple', linewidth=2)
ax[0, 0].set_title('Line Plot: Trend Over Time')
ax[0, 0].set_xlabel('Time')
ax[0, 0].set_ylabel('Value')
ax[0, 0].grid(True)

# Top-right: Scatter Plot
ax[0, 1].scatter(x_scatter, y_scatter, color='blue', alpha=0.7)
ax[0, 1].set_title('Scatter Plot: X vs Y')
ax[0, 1].set_xlabel('X Variable')
ax[0, 1].set_ylabel('Y Variable')
ax[0, 1].grid(True)

# Bottom-left: Bar Chart
ax[1, 0].bar(categories, values, color='skyblue')
ax[1, 0].set_title('Bar Chart: Categorical Comparison')
ax[1, 0].set_xlabel('Categories')
ax[1, 0].set_ylabel('Values')
ax[1, 0].grid(axis='y', linestyle='--', alpha=0.7)

# Bottom-right: Histogram
ax[1, 1].hist(hist_data, bins=30, edgecolor='black', alpha=0.7, color='green')
ax[1, 1].set_title('Histogram: Distribution')
ax[1, 1].set_xlabel('Value')
ax[1, 1].set_ylabel('Frequency')
ax[1, 1].grid(True, linestyle='--', alpha=0.6)

## Fixing overlapping/cramped plots

Now we have a common issue with this plot, one that lots of people run into when creating gridded/faceted plots: overlapping labels and titles.

How do we fix this?

1. If things are looking cramped, we can set a larger figure size with `figsize`
2. We can call the "magic" `fig.tight_layout()` function after plotting
3. We can add the argument `layout="constrained"` to the original `plt.subplots` call: `fig, ax = plt.subplots(2, 2, layout="constrained")`

Making the figure bigger with `figsize` will often fix the issue, but sometimes we want to keep a nice small plot! How do we approach it then?
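
The three fixes can also be compared side by side without copying the panel code three times. This sketch moves the drawing into a hypothetical helper, `draw_four_panels`, so only the layout step changes between figures. The random data and long titles are placeholders, and it assumes a recent Matplotlib (3.6 or newer, which the `ncols` legend argument used earlier already needs).

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)

def draw_four_panels(ax):
    """Hypothetical helper: fill a 2x2 array of Axes with long titles and labels."""
    for i, a in enumerate(ax.flat):
        a.plot(rng.random(20))
        a.set_title(f"A deliberately long panel title {i}")
        a.set_xlabel("X label")
        a.set_ylabel("Y label")

# Option 1: a bigger figure (figsize is (width, height) in inches)
fig_big, ax_big = plt.subplots(2, 2, figsize=(12, 10))
draw_four_panels(ax_big)

# Option 2: tight_layout() adjusts the spacing once, at the moment it is called
fig_tight, ax_tight = plt.subplots(2, 2)
draw_four_panels(ax_tight)
fig_tight.tight_layout()

# Option 3: constrained layout recomputes the spacing whenever the figure is drawn
fig_con, ax_con = plt.subplots(2, 2, layout="constrained")
draw_four_panels(ax_con)
```

Because constrained layout runs at draw time, it also adapts if you add a colourbar or legend after plotting, whereas `tight_layout()` only accounts for what exists when you call it.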

In [None]:
# try just modifying figsize first
# (figsize is (width, height) in inches)
fig, ax = plt.subplots(2, 2, figsize=(12, 10))


# Top-left: Line Plot
ax[0, 0].plot(x_line, y_line, color='purple', linewidth=2)
ax[0, 0].set_title('Line Plot: Trend Over Time')
ax[0, 0].set_xlabel('Time')
ax[0, 0].set_ylabel('Value')
ax[0, 0].grid(True)

# Top-right: Scatter Plot
ax[0, 1].scatter(x_scatter, y_scatter, color='blue', alpha=0.7)
ax[0, 1].set_title('Scatter Plot: X vs Y')
ax[0, 1].set_xlabel('X Variable')
ax[0, 1].set_ylabel('Y Variable')
ax[0, 1].grid(True)

# Bottom-left: Bar Chart
ax[1, 0].bar(categories, values, color='skyblue')
ax[1, 0].set_title('Bar Chart: Categorical Comparison')
ax[1, 0].set_xlabel('Categories')
ax[1, 0].set_ylabel('Values')
ax[1, 0].grid(axis='y', linestyle='--', alpha=0.7)

# Bottom-right: Histogram
ax[1, 1].hist(hist_data, bins=30, edgecolor='black', alpha=0.7, color='green')
ax[1, 1].set_title('Histogram: Distribution')
ax[1, 1].set_xlabel('Value')
ax[1, 1].set_ylabel('Frequency')
ax[1, 1].grid(True, linestyle='--', alpha=0.6)

In [None]:
# using tight_layout
fig, ax = plt.subplots(2, 2)


# Top-left: Line Plot
ax[0, 0].plot(x_line, y_line, color='purple', linewidth=2)
ax[0, 0].set_title('Line Plot: Trend Over Time')
ax[0, 0].set_xlabel('Time')
ax[0, 0].set_ylabel('Value')
ax[0, 0].grid(True)

# Top-right: Scatter Plot
ax[0, 1].scatter(x_scatter, y_scatter, color='blue', alpha=0.7)
ax[0, 1].set_title('Scatter Plot: X vs Y')
ax[0, 1].set_xlabel('X Variable')
ax[0, 1].set_ylabel('Y Variable')
ax[0, 1].grid(True)

# Bottom-left: Bar Chart
ax[1, 0].bar(categories, values, color='skyblue')
ax[1, 0].set_title('Bar Chart: Categorical Comparison')
ax[1, 0].set_xlabel('Categories')
ax[1, 0].set_ylabel('Values')
ax[1, 0].grid(axis='y', linestyle='--', alpha=0.7)

# Bottom-right: Histogram
ax[1, 1].hist(hist_data, bins=30, edgecolor='black', alpha=0.7, color='green')
ax[1, 1].set_title('Histogram: Distribution')
ax[1, 1].set_xlabel('Value')
ax[1, 1].set_ylabel('Frequency')
ax[1, 1].grid(True, linestyle='--', alpha=0.6)

fig.tight_layout()

In [None]:
# using layout="constrained"
fig, ax = plt.subplots(2, 2, layout="constrained")


# Top-left: Line Plot
ax[0, 0].plot(x_line, y_line, color='purple', linewidth=2)
ax[0, 0].set_title('Line Plot: Trend Over Time')
ax[0, 0].set_xlabel('Time')
ax[0, 0].set_ylabel('Value')
ax[0, 0].grid(True)

# Top-right: Scatter Plot
ax[0, 1].scatter(x_scatter, y_scatter, color='blue', alpha=0.7)
ax[0, 1].set_title('Scatter Plot: X vs Y')
ax[0, 1].set_xlabel('X Variable')
ax[0, 1].set_ylabel('Y Variable')
ax[0, 1].grid(True)

# Bottom-left: Bar Chart
ax[1, 0].bar(categories, values, color='skyblue')
ax[1, 0].set_title('Bar Chart: Categorical Comparison')
ax[1, 0].set_xlabel('Categories')
ax[1, 0].set_ylabel('Values')
ax[1, 0].grid(axis='y', linestyle='--', alpha=0.7)

# Bottom-right: Histogram
ax[1, 1].hist(hist_data, bins=30, edgecolor='black', alpha=0.7, color='green')
ax[1, 1].set_title('Histogram: Distribution')
ax[1, 1].set_xlabel('Value')
ax[1, 1].set_ylabel('Frequency')
ax[1, 1].grid(True, linestyle='--', alpha=0.6)


> Note: the [matplotlib documentation](https://matplotlib.org/stable/gallery/subplots_axes_and_figures/broken_axis.html) uses a grid of subplots to create "broken axes"!
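
A minimal sketch of that idea, loosely following the gallery example: the same data is drawn on two stacked subplots, each zoomed into a different y range, and the spines between them are hidden so they read as one axis with a gap. The data (mostly small values plus two large outliers) is made up for illustration, and the diagonal "cut" marks from the gallery version are left out.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up data: mostly small values, plus two large outliers
rng = np.random.default_rng(1)
pts = rng.random(30) * 0.2
pts[[3, 14]] = [0.95, 0.9]

fig, (ax_top, ax_bottom) = plt.subplots(2, 1, sharex=True)

# Same data on both panels, each zoomed into a different y range
ax_top.plot(pts, "o")
ax_bottom.plot(pts, "o")
ax_top.set_ylim(0.8, 1.0)      # just the outliers
ax_bottom.set_ylim(0.0, 0.22)  # the bulk of the data

# Hide the spines at the break so the two panels read as one axis
ax_top.spines["bottom"].set_visible(False)
ax_bottom.spines["top"].set_visible(False)
ax_top.xaxis.tick_top()
ax_top.tick_params(labeltop=False)  # no tick labels at the break
ax_bottom.xaxis.tick_bottom()
```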

## Exercise 3: faceted plots with heatmaps

In this exercise, create a grid similar to the one above, but include a dataset with a colour map that requires a colourbar.

- Where will you put the colourbar? How does placement work in a grid? Does it get messy?
- Where will you put your legend?

Have a look at ["Placing colourbars"](https://matplotlib.org/stable/users/explain/axes/colorbar_placement.html)

Remember, this was our scatterplot with a colour map:

```python
fig, ax = plt.subplots()
sp = ax.scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")
ax.set_title('Scatter Plot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

ax.scatter(x, -y-0.1, c="grey", s=200, alpha=0.6, edgecolor="black", label="Other data")

fig.colorbar(sp, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)

ax.legend(loc='center', bbox_to_anchor=(0.5, -0.02),
          bbox_transform=fig.transFigure, frameon=False,
          ncols=2)
```

In [None]:
# Example skeleton code: fill this in!

fig, ax = plt.subplots(2, 2, sharex=True, sharey=True)

# Top-left: ax[0, 0]

# Top-right: ax[0, 1]

# Bottom-left: ax[1, 0]

# Bottom-right: ax[1, 1]

In [None]:
# Incomplete solution
# Please try to come up with a solution on your own first!
# If you get stuck, have a look at this possible answer

fig, ax = plt.subplots(2, 2)

# Top-left: ax[0, 0]
# my scatterplot with a colourmap
scatter_plot = ax[0, 0].scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")

fig.colorbar(scatter_plot, label="Colour bar label",)
            #  orientation="horizontal", format="{x:.0%}",
            #  location="top", pad=0.10)

# I'm just putting the same plot but grey in the other locations

# Top-right: ax[0, 1]
scatter_plot = ax[0, 1].scatter(x, y, c="grey", alpha=0.8, label="Data with no z value")

# Bottom-left: ax[1, 0]
scatter_plot = ax[1, 0].scatter(x, y, c="grey", alpha=0.8)

# Bottom-right: ax[1, 1]
scatter_plot = ax[1, 1].scatter(x, y, c="grey", alpha=0.8)

fig.legend()

In [None]:
# Incomplete solution
# Please try to come up with a solution on your own first!
# If you get stuck, have a look at this possible answer

fig, ax = plt.subplots(2, 2, layout="constrained")
# note how constrained layout affects your legend

# Top-left: ax[0, 0]
# my scatterplot with a colourmap
scatter_plot = ax[0, 0].scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")

fig.colorbar(scatter_plot, label="Colour bar label",)
            #  orientation="horizontal", format="{x:.0%}",
            #  location="top", pad=0.10)

# I'm just putting the same plot but grey in the other locations

# Top-right: ax[0, 1]
scatter_plot = ax[0, 1].scatter(x, y, c="grey", alpha=0.8, label="Data with no z value")

# Bottom-left: ax[1, 0]
scatter_plot = ax[1, 0].scatter(x, y, c="grey", alpha=0.8)

# Bottom-right: ax[1, 1]
scatter_plot = ax[1, 1].scatter(x, y, c="grey", alpha=0.8)

fig.legend()

From a design point of view, how might we improve this?

- Remove "clutter" from the centre of the plot: move the colourbar to the edge of the figure
- Put the legend and colourbar on the same side

So I have two possible options here that I think might immediately make this tidier:

1. Move the colour-mapped plot to the top-right panel instead of the top-left, so the colourbar sits on the outside of the grid
2. Make the colourbar horizontal and place it above the plot, alongside the legend

In [None]:
# Incomplete solution
# Please try to come up with a solution on your own first!
# If you get stuck, have a look at this possible answer

fig, ax = plt.subplots(2, 2, layout="constrained")
# note how constrained layout affects your legend

# Top-left: ax[0, 0]
# my scatterplot with a colourmap
scatter_plot = ax[0, 0].scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")

fig.colorbar(scatter_plot, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)

# I'm just putting the same plot but grey in the other locations

# Top-right: ax[0, 1]
scatter_plot = ax[0, 1].scatter(x, y, c="grey", alpha=0.8, label="Data with no z value")

# Bottom-left: ax[1, 0]
scatter_plot = ax[1, 0].scatter(x, y, c="grey", alpha=0.8)

# Bottom-right: ax[1, 1]
scatter_plot = ax[1, 1].scatter(x, y, c="grey", alpha=0.8)

fig.legend()

In [None]:
# I want to make it tidier!

fig, ax = plt.subplots(2, 2, layout="constrained", sharex=True, sharey=True)

# Top-left: ax[0, 0]
# my scatterplot with a colourmap
scatter_plot = ax[0, 0].scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")

fig.colorbar(scatter_plot, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)

# I'm just putting the same plot but grey in the other locations

# Top-right: ax[0, 1]
scatter_plot = ax[0, 1].scatter(x, y, c="grey", alpha=0.8, label="Data with no z value")

# Bottom-left: ax[1, 0]
scatter_plot = ax[1, 0].scatter(x, y, c="grey", alpha=0.8)

# Bottom-right: ax[1, 1]
scatter_plot = ax[1, 1].scatter(x, y, c="grey", alpha=0.8)


# fig.legend(loc='center left', bbox_to_anchor=(0.62, 0.92),
#           bbox_transform=fig.transFigure, frameon=False)
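
Another option from the "Placing colourbars" guide: when every panel uses the same colour scale, a single colourbar can serve the whole grid. Passing the full array of axes as `ax=` makes `fig.colorbar` take space from all of them, so the colourbar sits beside the grid rather than inside one panel. This sketch uses made-up heatmap data via `imshow`, and fixes `vmin`/`vmax` on every panel, because the colourbar only describes the panels correctly if they all share one scale.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)

fig, axs = plt.subplots(2, 2, layout="constrained", sharex=True, sharey=True)

for a in axs.flat:
    # Fixed vmin/vmax so every panel uses the same colour scale
    im = a.imshow(rng.random((10, 10)), vmin=0, vmax=1, cmap="BuGn")

# ax=axs: steal space from the whole grid, not just one panel
cbar = fig.colorbar(im, ax=axs, shrink=0.8, label="Colour bar label")
```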

## Exercise 4: Save out your plots!

For the following plot:

- Add an appropriate title
- Add appropriate x and y labels
- Export it using the `fig.savefig()` function.

Useful links:

- [savefig function docs](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html)

In [None]:
# start with this script and modify it!

fig, ax = plt.subplots(2, 2, layout="constrained", sharex=True, sharey=True)

# Top-left: ax[0, 0]
# my scatterplot with a colourmap
scatter_plot = ax[0, 0].scatter(x, y, c=z, s=z*200, alpha=0.8, label="Data with z value")

fig.colorbar(scatter_plot, label="Colour bar label",
             orientation="horizontal", format="{x:.0%}",
             location="top", pad=0.10)


# Top-right: ax[0, 1]
scatter_plot = ax[0, 1].scatter(x, y, c="grey", alpha=0.8, label="Data with no z value")

# Bottom-left: ax[1, 0]
scatter_plot = ax[1, 0].scatter(x, y, c="grey", alpha=0.8)

# Bottom-right: ax[1, 1]
scatter_plot = ax[1, 1].scatter(x, y, c="grey", alpha=0.8)


fig.legend(loc='center left', bbox_to_anchor=(0.62, 0.92),
          bbox_transform=fig.transFigure, frameon=False)

# hints:

fig.suptitle("Overall title")

ax[0, 0].set_ylabel("Y value [unit]")

# fill in the rest!

## Test out save functions

Try using different file formats:

- pdf
- svg
- png

And for raster formats (like png), try changing the dpi.
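
A short sketch of saving one figure in all three formats. `savefig` picks the format from the file extension, `dpi` mainly matters for raster formats such as png, and `bbox_inches="tight"` trims or expands the saved area to fit everything drawn, which matters for legends placed outside the figure (like the one at `bbox_to_anchor=(0.5, -0.02)` earlier). The temporary folder and the file name `example_plot` are placeholders; use your own path in practice.

```python
import os
import tempfile
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot([0, 1, 2], [0, 1, 4])
ax.set_title("Example")

out_dir = tempfile.mkdtemp()  # throwaway folder for this demo
saved = []
for ext in ["pdf", "svg", "png"]:
    path = os.path.join(out_dir, f"example_plot.{ext}")
    # Format comes from the extension; dpi affects the png (raster) output
    fig.savefig(path, dpi=300, bbox_inches="tight")
    saved.append(path)
    print(path, os.path.getsize(path), "bytes")
```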

Note: set your overall figure size explicitly if you are modifying the dpi. dpi (dots per inch) is applied to the real size of your figure in inches, so the pixel size of a raster image is `figsize × dpi`!
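
To see the figure size and dpi working together, this sketch saves a 6 × 4 inch figure at two dpi values and reads the pixel dimensions back. Without `bbox_inches="tight"` the saved image should be exactly `figsize × dpi`, so 6 × 100 = 600 pixels wide at 100 dpi and 6 × 300 = 1800 at 300 dpi. The temporary folder and file names are placeholders.

```python
import os
import tempfile
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

fig, ax = plt.subplots(figsize=(6, 4))  # 6 x 4 inches
ax.plot([0, 1, 2], [0, 1, 4])

out_dir = tempfile.mkdtemp()  # throwaway folder for this demo
sizes = {}
for dpi in [100, 300]:
    path = os.path.join(out_dir, f"plot_{dpi}dpi.png")
    # No bbox_inches="tight", so the canvas is saved at its full size
    fig.savefig(path, dpi=dpi)
    height, width = mpimg.imread(path).shape[:2]
    sizes[dpi] = (width, height)
    print(dpi, "dpi ->", width, "x", height, "pixels")
```

Same figure, same layout, three times the pixels: that is why a larger dpi sharpens a plot without changing how big the text looks relative to the panels.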