# Intro to plotting with <code>matplotlib</code>
So you want to learn how to make plots with Python? <code>matplotlib</code> is by far the most commonly used package, and it contains basically every customization option you could hope for. This notebook will walk you through the basics, but when you want to branch out, don't be afraid to turn to tools like Stack Overflow and ChatGPT for help customizing your plot!

Before we get started, let's import <code>matplotlib</code> so we can actually use it. (We'll also import <code>numpy</code>, which will help us generate fake data for plotting.) Specifically, we import the <code>matplotlib.pyplot</code> module, whose typical shorthand name is <code>plt</code> (meaning "plot").

In [None]:
import numpy as np
import matplotlib.pyplot as plt

We'll also read in the dataset that we were working with earlier, which is a subset of data from the APOGEE survey. The survey makes high-resolution, spectroscopic observations of Milky Way red giant stars.

In [None]:
#Importing a data structure from the astropy package that will help read in our file
from astropy.table import Table

#Reading in the table
data = Table.read('HR_sample_data.fit') 
data

In [None]:
#This is how you access the data stored in each column from the table
#The syntax is: table_name['column_name']
ra = data['RAICRS']
dec = data['DEICRS']
magnitude = data['Vmag']
color = data['B-V']
parallax = data['Plx']

## Basic plotting
First, we'll examine three basic types of plots, which will also demonstrate the core functionalities of the <code>matplotlib</code> package.

### Line plot

In [None]:
#Generating some fake data
x = np.linspace(0, 1, 101) 

In [None]:
#Making a simple line plot
#For reference: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html

#Passing in two variables: x and y
plt.plot(x, x**2) 
#Setting the axis labels (you can pass in any string you'd like)
plt.xlabel('x')
plt.ylabel('y')
#Setting the title (you can pass in any string you like)
plt.title('Figure 1')
#This command tells your notebook to display the plot
plt.show()

In [None]:
#To plot additional lines, simply add another plt.plot command
#matplotlib will automatically assign different colors to the lines
plt.plot(x, x**2)
plt.plot(x, np.sqrt(x))
plt.xlabel('x')
plt.ylabel('y')
plt.title('Figure 2')
plt.show()

In [None]:
#If you'd like to control the color of the lines, use color='color'
#List of valid colors: https://matplotlib.org/stable/gallery/color/named_colors.html#css-colors
#You can also feed in RGB codes
plt.plot(x, x**2, color='red')
plt.plot(x, np.sqrt(x), color='black')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Figure 2')
plt.show()
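
The comment about RGB codes can be made concrete. The sketch below is a minimal example, assuming the same <code>x</code> grid as above. It specifies one line color as a hex string and the other as an RGB tuple of floats between 0 and 1. The variable names <code>hex_color</code> and <code>rgb_color</code> and the specific color values are arbitrary choices for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 1, 101)

#A hex string of the form '#RRGGBB'
hex_color = '#1f77b4'
#An (R, G, B) tuple, with each value between 0 and 1
rgb_color = (0.8, 0.2, 0.1)

plt.plot(x, x**2, color=hex_color)
plt.plot(x, np.sqrt(x), color=rgb_color)
plt.xlabel('x')
plt.ylabel('y')
plt.show()
```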

In [None]:
#To add a legend to the plot, you need to associate a "label" with each plt.plot call
plt.plot(x, x**2, color='red', label='y=x^2')
plt.plot(x, np.sqrt(x), color='black', label='y=sqrt(x)')

#This line actually adds the legend, using the labels you just defined
#It can be located anywhere below the lines that define the labels (doesn't have to be immediately after)
plt.legend()

#You can set the axis limits using the following syntax:
plt.xlim(0,1)
plt.ylim(0,1)

#Stuff we've already covered:
plt.title('Figure 3')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

### Histogram

Let's look at the APOGEE data now. What if we wanted to examine the distribution of magnitudes for our stars? A histogram is a good choice here because the magnitudes are a 1D dataset. To make a histogram in <code>matplotlib</code>, we can use the following function:

In [None]:
#Making a simple histogram:
#For reference: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.hist.html
plt.hist(magnitude, color='green')
plt.xlabel('V_mag', size=12) #size=12 will set the fontsize for the label
plt.ylabel('count', size=12)
plt.show()

In [None]:
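#You can specify the bins of the histogram either by number (by passing an int) or by an array of bin edges
#Option 1: an integer number of equal-width bins
bins = 100
#Option 2: an array of bin edges (this overwrites Option 1, so only the edges are used below)
bins = np.arange(0, 13, 0.1)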
plt.hist(magnitude, color='green', bins=bins)
plt.xlabel('V_mag', size=12) 
plt.ylabel('count', size=12)
plt.show()

In [None]:
#You can change the histogram to an outline by passing histtype='step'
#This is really useful for overplotting multiple histograms
plt.hist(magnitude, color='green', histtype='step')
plt.xlabel('V_mag', size=12) 
plt.ylabel('count', size=12)
plt.show()
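
To illustrate the point about overplotting, here is a minimal sketch that draws two outlined histograms on the same axes. The two samples are randomly generated stand-ins (not columns from the real dataset), and the names <code>sample_a</code> and <code>sample_b</code> are hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt

#Fake data standing in for two samples of magnitudes (not the real dataset)
rng = np.random.default_rng(42)
sample_a = rng.normal(8, 1, 1000)
sample_b = rng.normal(9, 1.5, 1000)

#Using the same bin edges for both histograms makes them directly comparable
bins = np.arange(2, 16, 0.5)
counts_a, _, _ = plt.hist(sample_a, bins=bins, histtype='step', color='green', label='sample A')
counts_b, _, _ = plt.hist(sample_b, bins=bins, histtype='step', color='purple', label='sample B')
plt.xlabel('V_mag', size=12)
plt.ylabel('count', size=12)
plt.legend()
plt.show()
```

Because only the outlines are drawn, neither histogram hides the other where they overlap.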

### Scatterplot
Let's try to plot magnitude versus color (observational proxies for luminosity and temperature) for all the stars in our sample.

In [None]:
#Naively, we might expect to do something like this:
plt.plot(color, magnitude)
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
plt.show()

Oh no! This doesn't look right at all. That's because <code>plt.plot()</code> connects the points with lines, in the order they appear in the arrays. This is good for plotting trajectories or orbits, for example, but not for these unordered data points. Here we should use <code>plt.scatter()</code> instead.

In [None]:
#Making a simple scatterplot:
#For reference: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.scatter.html
plt.scatter(color, magnitude)
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
plt.show()

(Technically, <code>plt.plot()</code> is capable of plotting both lines and points, so you can also achieve the intended effect by adding <code>linestyle='None', marker='o'</code> to the <code>plt.plot()</code> call. However, <code>plt.scatter()</code> is generally the better choice, since it was specifically designed to make scatterplots.)
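
For completeness, the <code>plt.plot()</code> alternative mentioned above might look like the following sketch. It uses randomly generated values in place of the real columns, and the names <code>fake_color</code> and <code>fake_magnitude</code> are hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt

#Fake color and magnitude values standing in for the real columns
rng = np.random.default_rng(0)
fake_color = rng.uniform(-0.2, 2, 200)
fake_magnitude = rng.uniform(2, 13, 200)

#linestyle='None' removes the connecting lines; marker='o' draws a circle at each point
plt.plot(fake_color, fake_magnitude, linestyle='None', marker='o')
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
plt.show()
```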

This is an improvement, but we can still make the plot easier to read. One of the confusing things about astronomy is that a smaller magnitude actually corresponds to a brighter star. This means that our plot currently has brighter stars at the bottom of the y-axis and fainter stars at the top. It would be better to invert the y-axis so that brightness increases upwards, as we would typically expect.

In [None]:
#Demonstrating how to change marker size (s=10) and shape (marker='x')
plt.scatter(color, magnitude, s=10, marker='x') 
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
#Inverting the y axis
#This will also work for the x axis, if you just change out "y" for "x"
#You could also achieve this effect by setting the y-limits in reverse
plt.gca().invert_yaxis()
plt.show()
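
The last comment in the cell above mentions setting the y-limits in reverse. A minimal sketch of that approach follows, again using fake data. The limits of 13 and 2 are assumed values chosen to bracket the fake magnitudes.

```python
import numpy as np
import matplotlib.pyplot as plt

#Fake color and magnitude values standing in for the real columns
rng = np.random.default_rng(0)
fake_color = rng.uniform(-0.2, 2, 200)
fake_magnitude = rng.uniform(2, 13, 200)

plt.scatter(fake_color, fake_magnitude, s=10, marker='x')
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
#Passing the larger value first flips the axis, so smaller magnitudes end up at the top
plt.ylim(13, 2)
plt.show()
```

Unlike <code>invert_yaxis()</code>, this approach also fixes the range of the axis, so it is most convenient when you already know the limits you want.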

## Advanced techniques

Now, we'll explore a few more advanced <code>matplotlib</code> techniques. These are just the tip of the iceberg -- remember, if you can dream it, <code>matplotlib</code> can probably do it! In fact, <code>matplotlib</code> supports many more types of plots than just the three shown above. Check out [this comprehensive list](https://matplotlib.org/stable/plot_types/index.html) for more!

### Colorbars

A colorbar allows you to map colors to scalar values, i.e., color-code your data points by a third parameter. This can be a nice way to show 3D information in a 2D space. For example, take the color-magnitude diagram we plotted in the scatterplot section above. What if we wanted to see how a third parameter varied in color-magnitude space?

In [None]:
#We add the third variable as the "color" dimension by writing c=variable
plt.scatter(color, magnitude, c=parallax, s=10, marker='x')
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
plt.gca().invert_yaxis()
#Then we add a colorbar to the plot with this call:
plt.colorbar(label='parallax') #label='name' will add a label to the colorbar
plt.show()

In [None]:
#matplotlib provides various colormaps that you can use besides the default 'viridis'
#For reference: https://matplotlib.org/stable/users/explain/colors/colormaps.html
plt.scatter(color, magnitude, c=parallax, s=10, marker='x', cmap='magma')
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
plt.gca().invert_yaxis()
plt.colorbar(label='parallax')
plt.show()

### Heatmaps

There are a couple of options for making heatmaps in <code>matplotlib</code>. Which one you choose depends on your data. For example, if we have image-like data, where each location in the heatmap can be treated as a pixel with a certain intensity, we want to use <code>plt.imshow()</code>.

In [None]:
#Generating a fake image
np.random.seed(19680801) #Fixing random state for reproducibility
rand_image = np.random.random((100, 100))

In [None]:
#Making a simple heatmap
#For reference: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html
plt.imshow(rand_image, cmap='magma') #We can change the colormap just like in the previous section
plt.colorbar()
plt.show()

You can also use <code>plt.hist2d()</code> to make a 2D histogram of your data, which differs from the above case in that you define the "pixel" boundaries in coordinates that are relevant to your data:

In [None]:
#Making a 2D histogram
#For reference: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.hist2d.html
plt.hist2d(color, magnitude, bins=50, range=[[-0.2, 2], [2,13]], cmap='gray_r')
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
plt.gca().invert_yaxis()
plt.show()
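
Instead of a number of bins plus a range, you can also hand <code>plt.hist2d()</code> the bin edges directly, which makes the "pixel" boundaries explicit. The sketch below assumes fake data and hypothetical edge arrays (<code>color_edges</code>, <code>mag_edges</code>). It also uses the image returned by <code>plt.hist2d()</code> to add a colorbar showing the counts.

```python
import numpy as np
import matplotlib.pyplot as plt

#Fake color and magnitude values standing in for the real columns
rng = np.random.default_rng(0)
fake_color = rng.uniform(-0.2, 2, 200)
fake_magnitude = rng.uniform(2, 13, 200)

#Bin edges in data coordinates: 22 bins along each axis
color_edges = np.linspace(-0.2, 2, 23)
mag_edges = np.linspace(2, 13, 23)

#hist2d returns the counts, both sets of edges, and the image that was drawn
counts, xedges, yedges, image = plt.hist2d(fake_color, fake_magnitude,
                                           bins=[color_edges, mag_edges], cmap='gray_r')
plt.xlabel('B-V (color)', size=12)
plt.ylabel('V_mag (magnitude)', size=12)
plt.gca().invert_yaxis()
plt.colorbar(image, label='count')
plt.show()
```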

### Subplots

What if I want to make two plots side by side? Or make some other complex layout? <code>matplotlib</code> allows you to do this with subplots. 

<span style="color:red">IMPORTANT NOTE: <code>matplotlib</code> subplots use a slightly different syntax from the commands we've been using so far. Rather than just using <code>plt.</code> calls, we will have to work with <code>Figure</code> and <code>Axes</code> objects. The below code demonstrates a simple way of doing this, but you will find that there are many ways to work with subplots, and you should explore on your own as the need arises!</span>

The syntax we'll use for making subplots is <code>plt.subplots(num_rows, num_columns, figsize=(width, height))</code>, where the figure size is given in inches. First, we'll show how you can recreate a single plot like the ones we've made so far with the subplots syntax. See the example below and note the changes:

1. <code>plt.scatter()</code> becomes <code>axes.scatter()</code>
2. <code>plt.xlabel()</code> becomes <code>axes.set_xlabel()</code> (and the same for the y-label)
3. The syntax for creating the colorbar is slightly different; we have to store the results of our plotting command and pass them into the colorbar command. In general, colorbars get confusing when we start working with <code>Figure</code> and <code>Axes</code> objects. You should look at the documentation, e.g. [this page](https://matplotlib.org/stable/users/explain/axes/colorbar_placement.html), for more advanced use cases. 

In [None]:
#Recreating a single plot with the new syntax
#For reference: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html
fig, axes = plt.subplots(1, 1, figsize=(9,6)) # fig is our Figure object and axes is our Axes object
im = axes.scatter(color, magnitude, c=parallax, s=10, marker='x', cmap='magma') 
axes.set_xlabel('B-V (color)', size=12)
axes.set_ylabel('V_mag (magnitude)', size=12)
axes.invert_yaxis()
plt.colorbar(im) 
plt.show()

Now we'll actually make two plots side-by-side:

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12,6)) 

#Making the first plot
axes[0].scatter(color, magnitude, s=10, marker='x')
axes[0].set_xlabel('B-V (color)', size=12)
axes[0].set_ylabel('V_mag (magnitude)', size=12)
axes[0].set_xlim(-0.2, 2)
axes[0].set_ylim(13, 2)

#Making the second plot
axes[1].hist2d(color, magnitude, bins=50, range=[[-0.2, 2], [2,13]], cmap='gray_r')
axes[1].set_xlabel('B-V (color)', size=12)
axes[1].set_ylabel('V_mag (magnitude)', size=12)
axes[1].set_xlim(-0.2, 2)
axes[1].set_ylim(13, 2)

plt.show()
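
With more than one panel, it helps to tell <code>matplotlib</code> which panel each colorbar belongs to. One way to do this, sketched below with fake data, is to call <code>fig.colorbar()</code> with the stored plotting result and the matching <code>Axes</code> object passed as <code>ax</code>. The arrays <code>fake_color</code>, <code>fake_magnitude</code>, and <code>fake_parallax</code> are hypothetical stand-ins for the real columns.

```python
import numpy as np
import matplotlib.pyplot as plt

#Fake data standing in for the real columns
rng = np.random.default_rng(0)
fake_color = rng.uniform(-0.2, 2, 200)
fake_magnitude = rng.uniform(2, 13, 200)
fake_parallax = rng.uniform(0, 50, 200)

fig, axes = plt.subplots(1, 2, figsize=(12,6))

#First panel: scatterplot color-coded by the third variable
im0 = axes[0].scatter(fake_color, fake_magnitude, c=fake_parallax, s=10, marker='x', cmap='magma')
axes[0].set_xlabel('B-V (color)', size=12)
axes[0].set_ylabel('V_mag (magnitude)', size=12)
axes[0].set_ylim(13, 2)
#ax=axes[0] attaches this colorbar to the first panel
fig.colorbar(im0, ax=axes[0], label='parallax')

#Second panel: 2D histogram (the last returned item is the image to pass to the colorbar)
counts, xedges, yedges, im1 = axes[1].hist2d(fake_color, fake_magnitude, bins=50,
                                             range=[[-0.2, 2], [2,13]], cmap='gray_r')
axes[1].set_xlabel('B-V (color)', size=12)
axes[1].set_ylabel('V_mag (magnitude)', size=12)
axes[1].set_ylim(13, 2)
fig.colorbar(im1, ax=axes[1], label='count')

plt.show()
```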

## Your turn! Let's practice!
Load in the provided dataset (<code>transit_data.dat</code>) and see if you can visualize the transit of a planet around a star. The columns in this dataset are not clearly labeled, but the first column is time (in unfamiliar units), the second is flux (or brightness of the star), and the third is the error on the flux. 

Try to make the following plots:
1. Flux vs time line plot
2. Flux vs time scatter plot
3. Flux vs time including errorbars (you should use Google and the <code>matplotlib</code> documentation for hints on how to do this).
4. Make the above plots visually appealing and manipulate them so that they clearly show that there is a planet in this system. It might be hard to see the transits at first -- you should adjust the axes to zoom in on one of them. Look for places where clusters of points seem to dip down below the usual pattern of the light curve.

To make these plots, you'll first need to figure out how to separate the time array from the flux array from the flux error array. This will require knowledge of indexing of <code>numpy</code> arrays, which we discussed last week.

<span style="color:red">IMPORTANT NOTE: You must change the path below to point to the location of the <code>transit_data.dat</code> file on your system! And remember, if it's in the same folder as this notebook, you can just use a local path to access it.</span>

In [None]:
#Change this path:
data_path = 'your/path/here/transit_data.dat'

#Loading in the data
transit_data = np.loadtxt(data_path)
transit_data

In [None]:
#Write your code in this cell (and add more cells as needed!)

*The content of this notebook was inspired by an earlier version made by former Columbia astronomy PhD student Courtney Carter.*