## Subplotting

So far, we've learned how to make single plots by calling pyplot functions directly, using commands like
`plt.plot()` or `plt.hist()`. However, sometimes it's useful to show several plots side by side in the same
figure. In this lecture, we'll talk about some of matplotlib's subplotting features and discuss how to
construct a scatterplot matrix. This ties back to our discussion of matplotlib's architecture, and while I'll
be showing you how to do this with scatterplots, you can arrange almost any kind of plot in a matrix.

In [None]:
%matplotlib inline
# Let's bring in pyplot and numpy
import matplotlib.pyplot as plt
import numpy as np

# Let's start with creating a single sine wave plot - this is just the data for a sine wave
x = np.linspace(0, 4 * np.pi, 200)
y = np.sin(x)

# And let's take a look at that
plt.plot(x, y)
# Let's also put a title on the axes
plt.gca().set_title('A single plot')

In [None]:
# Now, let's say we want to make a grid of plots. Once we've created a figure object, we can add
# subplots using the add_subplot() function. Notice that it takes three parameters: the first two
# numbers specify that we have 2 rows and 3 columns, and the last number is an index that lets you refer to a
# specific subplot.

# So, let's get a figure to work with: plt.gcf() returns the current figure, creating a new one if none exists
fig = plt.gcf()
# Now let's iterate over the 6 potential spots in our figure
for i in range(1, 7):
    # let's not plot something if we are at position 5 or 3, we'll leave these as holes
    if i != 5 and i != 3:
        # Now let's add a subplot. We specify the overall structure we expect the figure to take, 2 rows
        # and 3 columns, and the position of this item in the figure. Even though we iterate over a single
        # linear index, the figure maps it into this 2x3 row/column grid, numbering positions from left to
        # right and then top to bottom
        ax = fig.add_subplot(2, 3, i)
        # And we'll add some text to each subplot to make this clearer. Remember from the lecture on
        # annotation that, by default, the (x, y) point anchors the bottom left of the text, so we pass
        # ha='center' to center it horizontally. These axes are empty, so their limits default to 0 to 1,
        # which makes these coordinates behave like relative positioning
        ax.text(0.5, 0.2, str((2, 3, i)), fontsize=18, ha='center')
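
The row-major numbering described in the comments above can be checked directly. Here's a small sketch: `index_to_row_col` is a hypothetical helper introduced only for illustration, and `get_position()` returns each Axes' bounding box in figure-relative coordinates, so moving along a row should shift the box right and moving to the next row should shift it down.

```python
import matplotlib.pyplot as plt


def index_to_row_col(index, ncols):
    """Hypothetical helper: map a 1-based add_subplot index to a 0-based (row, col), row-major."""
    return (index - 1) // ncols, (index - 1) % ncols


fig = plt.figure()
corners = {}
for i in range(1, 7):
    ax = fig.add_subplot(2, 3, i)
    bbox = ax.get_position()  # Bbox in figure coordinates (0 to 1)
    corners[i] = (round(bbox.x0, 3), round(bbox.y0, 3))  # lower-left corner of this Axes
    print(i, index_to_row_col(i, 3), corners[i])
plt.close(fig)
```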

In [None]:
# Great, so our figure now contains a 2x3 grid with 4 subplots, and two of the positions (index 3 and 5) are
# empty. Importantly, the indexing of subplots starts at 1 and *not* zero.

# We can see that each subplot has its own Axes object, which renders its own x and y axis. We can also see
# that things look very crowded, and some axis tick labels overlap neighboring axes.

# Now, you don't have to create your figure iteratively like this. Instead, you can call the pyplot
# subplots() function and specify a shape. This function returns the figure object and an ndarray of Axes
# objects (older matplotlib versions print these as AxesSubplot). Let's take a look
fig, axes_array = plt.subplots(2, 3)
print(axes_array)
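
The shape of the returned array depends on the grid you ask for. By default, subplots() "squeezes" out length-1 dimensions, so a single row or column gives a 1-D array (indexed like `axs[0]`), a 2-D grid is indexed like `axs[i, j]`, and a 1x1 grid returns a single Axes. Passing `squeeze=False` always returns a 2-D array. A quick sketch:

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

fig, grid = plt.subplots(2, 3)                          # 2-D array, shape (2, 3)
fig_row, row = plt.subplots(1, 3)                       # squeezed to 1-D, shape (3,)
fig_one, single = plt.subplots()                        # 1x1: a single Axes, not an array
fig_ns, unsqueezed = plt.subplots(1, 3, squeeze=False)  # stays 2-D, shape (1, 3)

print(grid.shape, row.shape, isinstance(single, Axes), unsqueezed.shape)

# .flat walks the grid in the same row-major order that add_subplot's index uses
for k, ax in enumerate(grid.flat, start=1):
    ax.set_title(f"position {k}")

for f in (fig, fig_row, fig_one, fig_ns):
    plt.close(f)
```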

In [None]:
# Now, like most functions in matplotlib, we can pass keyword arguments to the function to control
# the formatting of a particular artist. However, the available keywords are often not listed in the docs
# directly. For instance, the subplots() documentation just lists **fig_kw, which means any other keyword
# arguments are collected and, in the docs' words, "All additional keyword arguments are passed to the
# pyplot.figure call." This means you will almost certainly have to experiment and hunt through function calls
# when using matplotlib.

# But we've already seen that we can pass the figsize as a tuple to the figure() function, so that means we
# should be able to pass that to the subplots() function too
plt.subplots(2, 3, figsize=(12, 5))
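
One way to confirm that figsize really was forwarded to the underlying figure() call is to read the size back from the returned figure. A small sketch; dpi is another figure() keyword that travels the same way, and the specific values here are arbitrary:

```python
import matplotlib.pyplot as plt

# figsize and dpi are not named parameters of subplots(); they are collected into **fig_kw
# and forwarded to pyplot.figure()
fig, axs = plt.subplots(2, 3, figsize=(12, 5), dpi=80)

width_in, height_in = fig.get_size_inches()
print(width_in, height_in, fig.dpi)  # figure size in inches, and dots per inch
plt.close(fig)
```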

In [None]:
# Ok, so that makes a bigger figure at the same time as building our subplots. As you can see, the pyplot
# scripting layer can feel like there is a lot of magic going on, especially if you are used to object-oriented
# libraries. But with the docs at hand, it's actually pretty straightforward.

# Let's replot that figure and then adjust the spacing between subplots using the subplots_adjust() method.
# pyplot also provides plt.subplots_adjust(), but here we'll get the current figure and call the Figure method
# directly
plt.subplots(2, 3, figsize=(12, 5))
# Let's add some whitespace between axes: hspace and wspace are fractions of the average axes height and width
plt.gcf().subplots_adjust(hspace=0.35, wspace=0.35)
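
The spacing values end up in the figure's subplotpars, so you can read them back to check what was applied. pyplot also has a plt.subplots_adjust() function that acts on whichever figure is current, so it should give the same result as calling the Figure method. A sketch:

```python
import matplotlib.pyplot as plt

# Figure method: adjust the figure we got back from subplots()
fig_a, _ = plt.subplots(2, 3, figsize=(12, 5))
fig_a.subplots_adjust(hspace=0.35, wspace=0.35)

# pyplot function: adjusts the current figure (here, the one just created)
fig_b, _ = plt.subplots(2, 3, figsize=(12, 5))
plt.subplots_adjust(hspace=0.35, wspace=0.35)

for fig in (fig_a, fig_b):
    params = fig.subplotpars  # holds left/right/top/bottom/hspace/wspace for this figure
    print(params.hspace, params.wspace)
    plt.close(fig)
```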

In [None]:
# Ok, so those are the mechanics of subplots. But I teased you with a sine wave earlier, so let's get back
# to our example and talk about wave interference, and how we can visualize this kind of data.

![wave interference](assets/wave_interference.jpg)

In [None]:
# When two waves collide with each other, as shown in this picture, their amplitudes (or heights, in the case
# of the water wave here) are summed together. If the two waves are synchronized, or in phase, this results in
# constructive interference, and we should see a larger wave. On the other hand, if they are perfectly out of
# phase, this results in destructive interference (where the flat dips are). So let's try to plot two
# different waves and the result of adding them together, all on one figure.
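
The sum-to-product identity makes this precise for two equal-amplitude sine waves: sin(x) + sin(x + φ) = 2cos(φ/2) sin(x + φ/2). So the combined wave has amplitude 2|cos(φ/2)|, which is 2 when the waves are in phase (φ = 0) and 0 when they are exactly out of phase (φ = π). A quick numerical check of the identity with NumPy:

```python
import numpy as np

x = np.linspace(0, 4 * np.pi, 220)


def combined_amplitude(shift):
    """Peak height of sin(x) + sin(x + shift), from the sum-to-product identity."""
    return 2 * abs(np.cos(shift / 2))


for shift in (0.0, np.pi / 2, np.pi):
    direct = np.sin(x) + np.sin(x + shift)
    identity = 2 * np.cos(shift / 2) * np.sin(x + shift / 2)
    assert np.allclose(direct, identity)  # the identity holds at every sample point
    print(f"shift={shift:.3f}  amplitude={combined_amplitude(shift):.3f}")
```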

In [None]:
# First let's specify some x values, these will be shared between our waves
x = np.linspace(0, 4 * np.pi, 220)

# Now I'm going to shift the second wave by π. Feel free to play around with this shift value
shift = 1 * np.pi
# Also, I'm going to round our values to 5 decimal places to avoid plotting tiny floating-point round-off errors
y1 = np.round(np.sin(x), 5)
y2 = np.round(np.sin(x + shift), 5)
y3 = y1 + y2

# So now we have three waves: y1 and y2 are just sine waves, and y3 is the combination of these
# two waves. Let's plot them all side by side
fig, axs = plt.subplots(1, 3, figsize=(12,3))
axs[0].plot(x, y1)
axs[1].plot(x, y2)
axs[2].plot(x, y3)
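
To see why the rounding helps, compare the unrounded sum with the rounded one. Because np.pi is only a floating-point approximation of π, sin(x + np.pi) is not exactly -sin(x), so the unrounded "flat" wave is really a band of tiny nonzero values, and matplotlib would autoscale the y-axis to that band so it looks like noise. This is likely the round-off error mentioned in the comment above. A sketch:

```python
import numpy as np

x = np.linspace(0, 4 * np.pi, 220)
shift = np.pi

raw_sum = np.sin(x) + np.sin(x + shift)  # mathematically zero everywhere
rounded_sum = np.round(np.sin(x), 5) + np.round(np.sin(x + shift), 5)

# The raw sum is not exactly zero, only within floating-point error of it
print(np.max(np.abs(raw_sum)))
# After rounding to 5 decimal places, any leftover error is at most on the order of 1e-5
print(np.max(np.abs(rounded_sum)))
```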

In [None]:
# Cool. So when the two waves are exactly out of phase, they cancel out and form a straight line. Now, the
# previous setup isn't ideal, since it's hard to see how the first two plots relate to each other and combine to
# produce the third. So instead, let's try stacking them vertically and sharing the x-axis so that the points line
# up with each other

# We'll double the length of the x-axis (4 cycles instead of 2)
x = np.linspace(0, 8 * np.pi, 200)

# Once again, feel free to play around with this shift value afterwards
shift = 1 * np.pi
y1 = np.round(np.sin(x), 5)
y2 = np.round(np.sin(x + shift), 5)
y3 = np.round(y1 + y2, 5)

# And let's stack the charts this time. We can also indicate that they should share the x and y axes
fig, axs = plt.subplots(3, 1, figsize=(12,4), sharex=True, sharey=True)
axs[0].plot(x, y1)
axs[1].plot(x, y2)
axs[2].plot(x, y3)

In [None]:
# Now you can see the pattern much more clearly! Just focus on a single point and notice how the peaks and
# troughs of these two curves line up and "cancel" each other out (1 + -1 = 0). There's also no need to
# manually set the y bounds since we've specified that the axes markings should be shared.
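
Sharing is more than cosmetic: Axes created with sharey=True use a single set of y-limits, so changing the limits on one updates the others. This is what lets the animation later in this lecture set the limits on just one Axes object. A quick sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 8 * np.pi, 200)
fig, axs = plt.subplots(3, 1, figsize=(12, 4), sharex=True, sharey=True)
axs[0].plot(x, np.sin(x))
axs[1].plot(x, np.sin(x + np.pi))

# Set the y-limits on the bottom Axes only...
axs[2].set_ylim(-2, 2)

# ...and every shared Axes reports the same limits
for ax in axs:
    print(ax.get_ylim(), ax.get_xlim())
plt.close(fig)
```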

In [None]:
# Ok, I want to show you a bit of an easter egg here. While there are a couple of ways to make this easier
# to explore, I'm going to do a very simple animation using an animated GIF. In this example, I'm going to
# use the Python Imaging Library (PIL, provided today by the Pillow package) to do so. So let's import its
# Image module now
import PIL.Image

# Now our top image will be static, that's just the first waveform, so let's create that now
x = np.linspace(0, 8 * np.pi, 200)
y1 = np.round(np.sin(x), 5)

# We're going to hold the individual frames of the animated GIF in a new variable, called ims
ims = []

# And let's create a bunch of shift values, ranging from 0 to π. I'm going to create 20 of these, but the more
# you create, the smoother (and slower to generate) your animation is going to be
shifts = np.linspace(0, np.pi, 20)

# For each shift value we're going to create a new subplot, just like we did before
for shift in shifts:
    fig, axs = plt.subplots(3, 1, figsize=(12,4), sharex=True, sharey=True)
    y2 = np.round(np.sin(x + shift), 5)
    y3 = np.round(y1 + y2, 5)

    axs[0].plot(x, y1)
    axs[1].plot(x, y2)
    axs[2].plot(x, y3)
    # Now, we need to set the y-axis limits of one of the plots to -2 to 2, because the waves will interfere
    # and form peaks up to twice the original height, and fixed limits keep the scale the same in every frame.
    # Since we set sharey when creating the subplots, we don't have to set this on each Axes object
    axs[2].set_ylim(-2, 2)
    
    # The next bit of code renders the figure, then copies the rendered pixels from the figure canvas into a
    # PIL image. buffer_rgba() exposes the canvas as RGBA bytes, np.asarray() views those as a
    # (height, width, 4) ndarray, and PIL's fromarray() turns that into an image object, which we convert to RGB
    fig.canvas.draw()
    ims.append(PIL.Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert('RGB'))

    # Now we can free up the memory from the plot
    plt.close(fig)
    
# The final step is to save our image list into a single file. We can do this by taking the first image in the
# list and calling save() with save_all=True, passing in the rest of the images we want appended. The
# duration (milliseconds per frame) and loop=0 (repeat forever) arguments are optional Pillow GIF settings
ims[0].save('out.gif', save_all=True, append_images=ims[1:], duration=100, loop=0)

So, depending on how many images you decided to generate, this might take a little while to run. The result
is that we have an image on the file system called "out.gif". We can now render it inline in a
markdown cell, like this:

![](out.gif)

In [None]:
# And of course, you can also head to the Jupyter file browser by clicking the logo in the upper left,
# navigate until you find that image, and view it in the browser or even download it
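
If you want to check the result without opening the file, Pillow can read the GIF back and report how many frames it holds. Here's a smaller, self-contained sketch of the same pipeline; the helper name `figure_to_image`, the 5-frame count, the temporary output path, and the timing values are illustrative choices, not part of the lecture's code:

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def figure_to_image(fig):
    """Hypothetical helper: render a figure on its Agg canvas and copy it into a Pillow RGB image."""
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())  # shape (height, width, 4), dtype uint8
    return Image.fromarray(rgba).convert("RGB")


x = np.linspace(0, 8 * np.pi, 200)
y1 = np.sin(x)
frames = []
for shift in np.linspace(0, np.pi, 5):
    fig, axs = plt.subplots(3, 1, figsize=(6, 3), sharex=True, sharey=True)
    axs[0].plot(x, y1)
    axs[1].plot(x, np.sin(x + shift))
    axs[2].plot(x, y1 + np.sin(x + shift))
    axs[2].set_ylim(-2, 2)
    frames.append(figure_to_image(fig))
    plt.close(fig)

out_path = os.path.join(tempfile.mkdtemp(), "interference.gif")
frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=150, loop=0)

# Read the file back to confirm every frame made it in
with Image.open(out_path) as gif:
    n_frames = gif.n_frames
print(len(frames), n_frames, frames[0].size)
```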

## SPLOMS

So, we now understand why a figure contains Axes abstractions: a figure might have several Axes objects
that show multiple views of the data. A common visual exploration technique in data science is the SPLOM,
which stands for scatterplot matrix. SPLOMs are particularly useful for seeing the relationships between a
number of different variables at a glance. For this example, I want to load in a DataFrame of people's
credit card balance information and explore it.

In [None]:
# Let's bring in pandas and load our DataFrame
import pandas as pd
df = pd.read_csv("assets/Credit.csv")
df.head()

We're only going to be looking at a subset of this data set. Specifically, we're interested in exploring
relationships between people's

- Income (in units of 10,000 US dollars) 
- Rating (credit rating) 
- Age (age in years) 
- Education (number of years of education) 
- Balance (average credit card balance in USD)

But in case you're curious, the descriptions for all of these variables can be found at the link I've listed
below

https://vincentarelbundock.github.io/Rdatasets/doc/ISLR/Credit.html

In [None]:
# Let's first capture a list of the variables we are interested in
cols = ['Income', 'Rating', 'Age', 'Education', 'Balance']

# Now we need to create a square grid of subplots, with as many rows and columns as the number of variables
# we want to explore, in this case that's a 5x5 grid
fig, axs = plt.subplots(len(cols), len(cols), figsize=(10,10))

# Now we want to iterate across each column in our dataframe and compare it to each other column in our
# DataFrame. We'll do this with a couple of nested for loops
for i in range(len(cols)):
    for j in range(len(cols)):
        # Now we just want to plot a scatter plot comparing columns i and j. Sometimes this will be the
        # same column, so we would expect to see a diagonal line. I'm going to set the marker size to 5
        # just to make things a bit clearer
        axs[i,j].scatter(df[cols[j]], df[cols[i]], s=5)

        # Also, we've seen that when we plot multiple Axes, things get cluttered with axis tick marks and
        # labels, so let's turn those off
        axs[i,j].get_xaxis().set_visible(False)
        axs[i,j].get_yaxis().set_visible(False)

        # Then we'll turn them back on only if we are the last row...
        if i == len(cols) - 1:
            axs[i,j].get_xaxis().set_visible(True)
            axs[i,j].set_xlabel(cols[j])
            
        # ...and similarly, only show the y-axis labels for the first column.
        if j == 0:
            axs[i,j].get_yaxis().set_visible(True)
            axs[i,j].set_ylabel(cols[i])

# Now let's take a look at our data!

Ok, this takes a bit to run, but we can see that there is quite a bit of information contained in this
compact figure! Many of the variables seem to be uncorrelated. For instance, there does not seem to be any
relationship between age (the third row) and education (the fourth column), as the points are spread all
over the space. However, there does seem to be a relationship between one's credit rating (the second row
from the top) and income (the first column): there is a clear positive trend. Do you see any other
interesting trends in this data?
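
To put a number on trends you spot by eye, one option is the Pearson correlation matrix from pandas' `DataFrame.corr()`; on the real data that would be `df[cols].corr()`. The sketch below uses synthetic stand-in data, so its values say nothing about the Credit dataset. It only shows how a planted linear relationship shows up as a coefficient near 1 while an unrelated column stays near 0:

```python
import numpy as np
import pandas as pd

# Synthetic stand-in data (randomly generated, NOT the Credit data)
rng = np.random.default_rng(42)
n = 400
income = rng.uniform(1, 15, n)
demo = pd.DataFrame({
    'Income': income,
    'Rating': 200 + 40 * income + rng.normal(0, 30, n),  # planted linear trend with Income
    'Age': rng.uniform(20, 90, n),  # generated independently of the other columns
})

corr = demo.corr()  # Pearson correlation by default
print(corr.round(2))
```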

Now, there was a lot we had to do here to build this SPLOM, but it turns out this is a very useful visual
exploration technique, and many libraries make it much easier. In fact, in the next video, Nia
Dowell will demonstrate how to explore this kind of data much faster using the Seaborn library. It's
often convenient to do a quick web search and grab whichever library seems to answer your question
quickly, and there's nothing wrong with that! But sometimes you need to understand the mechanics
underneath, such as when you want to build a custom visualization or when you need to tweak output in a
specific way. So my goal in showing you how to build a SPLOM by hand is to build your confidence that you
can work with a broad range of tools as the situation arises.
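
As one example of such a shortcut, pandas ships `pandas.plotting.scatter_matrix()`, which builds the same kind of grid and returns a 2-D array of Axes. The sketch below uses synthetic stand-in data (randomly generated, not the Credit dataset); on the real data you would pass `df[cols]` instead. Note that by default it draws a histogram on each diagonal cell rather than the straight-line self-comparisons our hand-built loop produces:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for df[cols]: random values, NOT the Credit data
rng = np.random.default_rng(0)
n = 200
income = rng.uniform(1, 15, n)
demo = pd.DataFrame({
    'Income': income,
    'Rating': 200 + 40 * income + rng.normal(0, 30, n),  # built to trend with Income
    'Age': rng.integers(20, 90, n),
    'Education': rng.integers(5, 21, n),
    'Balance': rng.uniform(0, 2000, n),
})

# Extra keyword arguments such as s (marker size) are passed on to each scatter call
axes = pd.plotting.scatter_matrix(demo, figsize=(10, 10), s=5, diagonal='hist')
print(axes.shape)
# Column names become axis labels, like the manual version
print(axes[-1, 0].get_xlabel(), axes[1, 0].get_ylabel())
plt.close('all')
```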