# Advanced Plotting with Matplotlib

## PHYS 240
## Dr. Wolf

# But first, imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('default')

# Basic Example: Plotting Hyperbolic functions in `matplotlib`:
Recall (or behold, if you haven't seen them before!) the hyperbolic trigonometric functions, $\sinh$ and $\cosh$:

$$\sinh(x) = \frac{e^x-e^{-x}}{2}\qquad \cosh(x) = \frac{e^x + e^{-x}}{2}$$

`numpy` provides shortcut functions for these. Let's first visualize them with `matplotlib` like we always have:

In [None]:
xs = np.linspace(-3, 3, 1000)
ys1 = np.sinh(xs)
ys2 = np.cosh(xs)
plt.plot(xs, ys1, label=r'$\sinh(x)$')
plt.plot(xs, ys2, ls='--', label=r'$\cosh(x)$')
plt.xlabel('$x$'); plt.ylabel('$f(x)$'); plt.legend(loc='best')

# State Machine vs. Object-Oriented Approach
In using `pyplot` (`plt`) we are implicitly using the "state machine" of `matplotlib`.
- There is one figure that `plot`, `legend`, `xlabel`, etc. implicitly "go to"
- To create a new figure for subsequent calls to go to, you need to call `plt.figure()`, but then it is hard to get a handle on the older figure
- `matplotlib` has a sense of "state"; what the current figure is and how it is set up
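To make the idea of "state" concrete, here is a small sketch using `pyplot`'s bookkeeping functions: `plt.gcf()` ("get current figure") and `plt.figure(num)`, which re-activates an existing figure by its number. This is one way to get back to an older figure, not the only one.

```python
import matplotlib.pyplot as plt

# each call to plt.figure() creates a new figure and makes it "current"
fig_a = plt.figure()
fig_b = plt.figure()
print(plt.gcf() is fig_b)  # True: state-machine calls now go to fig_b

# to send calls back to the first figure, re-activate it by its number
plt.figure(fig_a.number)
plt.plot([0, 1], [0, 1])  # this line lands on fig_a (an axes is created for it)
current_fig = plt.gcf()
```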

# State Machine vs. Object-Oriented Approach
We'll move to using the object-oriented approach. All aspects of a figure are instances of various classes, each with its own attributes and methods.
- `Figure` encompasses an entire figure of any kind, plots or otherwise
- `Axes` is a single set of coordinate axes (one plot panel), including any plotted data; this is what we usually want to work with
- `Line2D` is an object that controls a single line-plot's worth of data on a given `Axes` instance

There are a dizzying number of distinct classes that make up figures, but the most important are `Figure` and `Axes`.
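A quick way to see these classes is to check what `plt.subplots` and `ax.plot` actually return. This sketch only inspects the objects; note that `plot` returns a *list* of `Line2D` instances, one per line drawn.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
# ax.plot returns a list of Line2D objects; here there is only one line
lines = ax.plot([0, 1, 2], [0, 1, 4])
line = lines[0]

print(isinstance(fig, Figure), isinstance(ax, Axes), isinstance(line, Line2D))  # True True True
# the objects know about each other: figure -> axes -> lines
print(ax in fig.axes, line in ax.lines, ax.figure is fig)  # True True True
```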

# Object-Oriented Approach: Creating Figures
A figure object is instantiated by `pyplot`'s `figure()` function. There are many optional arguments to this function, but the most useful is probably `figsize`, which you can set to a tuple giving the dimensions (width, height) of the desired figure, in inches.

So the following would create a figure that is 6 inches wide by 4 inches tall and stores it in `fig`:

```python
fig = plt.figure(figsize=(6, 4))
```
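The order inside `figsize` is (width, height). A small sketch confirming this with `Figure.get_size_inches`, plus `Figure.set_size_inches` for resizing a figure that already exists:

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 4))
width, height = fig.get_size_inches()  # returned as a numpy array [width, height]
print(width, height)  # 6.0 4.0

# the size can also be changed after the figure is created
fig.set_size_inches(8, 3)
```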

# Object-Oriented Approach: Adding Axes to a Figure
Once a figure is instantiated, we can add an `Axes` object, which is what we can use to plot data. The classic way to do this is to call the `add_subplot` method of the `Figure` instance you want to add the axes to.

We can see how we are already "naming" objects so that we can get at them, possibly juggling multiple axes (panels) in multiple figures.

`add_subplot` takes three integer arguments (sort of). They represent the number of rows of subplots in the figure, then the number of columns of subplots, and finally the index (starting at 1 for some reason) of the subplot we actually want to create. Indexing goes from left to right across the top row, then continues left to right along the next row down, and so on.

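**Shortcut:** Rather than passing three integers, you can pass a single three-digit integer whose hundreds, tens, and ones digits are the three original arguments, so `fig.add_subplot(1, 2, 2)` becomes `fig.add_subplot(122)`. This only works when each argument is a single digit.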
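A sketch that checks the index ordering on a 2-by-2 grid by asking each axes for its position in figure coordinates (`get_position()`, where 0 to 1 spans the figure). Indices run left to right across the top row before moving down; the three-digit shortcut is used for the bottom row.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
# 2-by-2 grid: index 1 is top left, 2 top right, 3 bottom left, 4 bottom right
ax_top_left = fig.add_subplot(2, 2, 1)
ax_top_right = fig.add_subplot(2, 2, 2)
ax_bottom_left = fig.add_subplot(223)  # shortcut for (2, 2, 3)
ax_bottom_right = fig.add_subplot(224)

# label each panel with its index (fig.axes keeps the order they were added)
for index, ax in enumerate(fig.axes, start=1):
    ax.set_title(f'index {index}')

# get_position() returns each axes' bounding box in figure coordinates
pos_tl = ax_top_left.get_position()
pos_tr = ax_top_right.get_position()
pos_bl = ax_bottom_left.get_position()
print(pos_tr.x0 > pos_tl.x0, pos_bl.y0 < pos_tl.y0)  # True True
```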

# Example: Creating figures and axes

In [None]:
# create the figure
fig = plt.figure()

# create left axis object in a 1-row, 2-columns setup
ax1 = fig.add_subplot(1, 2, 1)
# create right axis object in a 1-row, 2-columns setup using the `add_subplot` shortcut
ax2 = fig.add_subplot(122)

# Shortcut: Create the figure and all subplots in one call!
Not well-known enough: `plt.subplots`. This function combines aspects of `plt.figure` and the `add_subplot` method of figures, ultimately returning a figure and a `numpy` array of axes objects, which are instead **indexed from zero**. The function takes two integers, which are the same as the first two arguments to `add_subplot`, but it has many other optional arguments. If there is only one row or only one column, the array is one-dimensional; with more than one of each, it is two-dimensional and indexed as `axes[row, column]`. If there is only one row and one column, the return value is just a tuple of a figure and a single axes object.

In [None]:
# create figure (fig) and an ARRAY of axes objects (axes) over three rows and one column
fig, axes = plt.subplots(3, 1, figsize=(6, 6))
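Since the shape of what `plt.subplots` returns depends on the grid, here is a sketch of the common cases. The `squeeze=False` option, which always gives a 2D array, is shown last in case you want code that handles any grid shape the same way.

```python
import matplotlib.pyplot as plt

# one column: a 1D array of three Axes
fig_col, axes = plt.subplots(3, 1)
print(axes.shape)  # (3,)

# several rows AND columns: a 2D array, indexed [row, column] from zero
fig_grid, grid = plt.subplots(2, 3)
print(grid.shape)  # (2, 3)
grid[1, 0].set_title('row 1, column 0')

# a single row and column: just one Axes object, not an array
fig_single, ax = plt.subplots()
ax.plot([0, 1], [1, 0])

# squeeze=False always returns a 2D array, whatever the grid shape
fig_sq, always_2d = plt.subplots(1, 1, squeeze=False)
print(always_2d.shape)  # (1, 1)
```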

# Subplots/Axes: The Main Event!
- Axes objects are the ones that can accept `plot` and `scatter` method calls.
- Lines, labels, and legends are all attached to axes objects, and _most_ of the functions are the same as in `pyplot`'s state machine approach.
- The main differences are in labels (`xlabel()` → `set_xlabel()`), limits (`xlim()` → `set_xlim()`), and titles (`title()` → `set_title()`).
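A minimal sketch of these renamed setters, plus `Axes.set`, which accepts several properties at once as keyword arguments:

```python
import matplotlib.pyplot as plt
import numpy as np

xs = np.linspace(0, 2 * np.pi, 200)
fig, ax = plt.subplots()
ax.plot(xs, np.sin(xs))

# pyplot's xlabel/ylabel/xlim/title become set_* methods on the Axes
ax.set_xlabel('$x$')
ax.set_ylabel(r'$\sin(x)$')
ax.set_xlim(0, 2 * np.pi)
ax.set_title('One period')

# Axes.set bundles several set_* calls into one
ax.set(ylim=(-1.1, 1.1), yticks=[-1, 0, 1])
```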

# Example: Hyperbolic trig functions revisited: Object-Oriented Approach (Solution at end)

In [None]:
xs = np.linspace(-3, 3, 1000)
ys1 = np.sinh(xs)
ys2 = np.cosh(xs)

# do plotting together in class

# Challenge: 3 Harmonics in 3 Panels (Solution at end)
Goal: plot the three functions $f_1(x) = \sin(x)$, $f_2(x) = \sin(2x)$, and $f_3(x) = \sin(3x)$ in three vertically-stacked panels, all from $x=0$ to $x=2\pi$.

**Pro Tip**: Specifying `sharex=True` in the call to `subplots` will force stacked plots to share the same $x$-limits and ticks (and, by default, to show $x$ tick labels only on the bottom panel), regardless of the range over which each is plotted
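A small sketch of what `sharex=True` does, separate from the challenge: the two panels below are plotted over different $x$ ranges on purpose, but end up with identical $x$-limits covering both datasets.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, (top, bottom) = plt.subplots(2, 1, sharex=True)
# the panels cover different x ranges on purpose
x_short = np.linspace(0, 1, 50)
x_long = np.linspace(0, 3, 50)
top.plot(x_short, x_short)
bottom.plot(x_long, x_long**2)

# both panels now report the same x-limits, spanning the union of the data
print(top.get_xlim(), bottom.get_xlim())
```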

In [None]:
# first, create the data

# Make figure and axes: 3 rows, one column!

# plot each function in its own axis


# Focus: Customizing Dashstyles
We've seen that you can specify a linestyle for a `plot` command using simple strings like `'-'` (solid; default), `'--'` (dashed), `':'` (dotted), and `'-.'` (dash-dotted). But you can actually customize these to an arbitrary pattern by providing a value to the optional `dashes` keyword argument.

`dashes` takes an iterable that gives the lengths of alternating "on" and "off" segments, measured in points (by default, `matplotlib` also scales these lengths by the line width). For example,

```python
ax.plot(xs, ys, dashes=[3, 1.5, 1.5, 1.5])
```
will plot with a pattern of 3 points drawn, 1.5 points off, 1.5 on, and another 1.5 off, and then the pattern repeats. This is essentially a dot-dash pattern.
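A short sketch of the same dot-dash pattern written two ways: through `dashes`, and as a `linestyle` tuple of the form `(offset, (on, off, ...))`, which `matplotlib` also accepts. Both lines should be drawn with identical patterns.

```python
import matplotlib.pyplot as plt
import numpy as np

xs = np.linspace(0, 2 * np.pi, 200)
fig, ax = plt.subplots()

# 3 on, 1.5 off, 1.5 on, 1.5 off, then repeat (a dash-dot)
ax.plot(xs, np.sin(xs), dashes=[3, 1.5, 1.5, 1.5], label='dashes keyword')

# the same pattern as an (offset, on_off_sequence) linestyle tuple
ax.plot(xs, np.cos(xs), linestyle=(0, (3, 1.5, 1.5, 1.5)), label='linestyle tuple')
ax.legend(loc='best')
```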

# Example: Lots of Dashes!
If you find yourself in a situation where you need this many linestyles in a single figure, you should first reconsider some life choices. Every now and then, though, it makes sense to have this many linestyles.

In [None]:
# an empty dash sequence gives a solid line; the rest are (on, off, ...) lengths in points
dashstyles = ([], [3, 1.5], [1.5, 1.5], [3, 1.5, 1.5, 1.5], [3, 1.5, 1.5, 1.5, 1.5, 1.5])
fig, ax = plt.subplots(1, 1)
for i, ds in enumerate(dashstyles, start=1):
    ax.plot(xs, np.sin(i * xs), dashes=ds, label=rf'$\sin({i}x)$')
ax.legend(loc='best')


# Focus: Adding Markers to Line Plots
We know that we can make scatter plots (markers only; no connecting lines) by using the `scatter` method of axes, but we can also include markers in a standard plot by specifying the `marker` keyword argument. By default, it will add a marker at *every point*, which is probably too often, so you can also specify `markevery`, which will only plot a marker for every $n$ points, for some chosen value of $n$.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(4.5,4.5))
ax.plot(xs, ys1, marker='o', markevery=80, label=r'$\sinh(x)$')
ax.plot(xs, ys2, marker='^', markevery=80, ls='--', label=r'$\cosh(x)$')
ax.set_xlabel('$x$'); ax.set_ylabel('$f(x)$'); ax.legend(loc='best')
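`markevery` accepts more than a single integer. This sketch shows the integer form and the `(start, n)` tuple form, which offsets where the first marker goes (handy so two curves' markers don't sit at the same $x$ values):

```python
import matplotlib.pyplot as plt
import numpy as np

xs = np.linspace(-3, 3, 1000)
fig, ax = plt.subplots()
# an integer n marks indices 0, n, 2n, ...; here 0, 80, ..., 960
every_80, = ax.plot(xs, np.sinh(xs), marker='o', markevery=80)
# a (start, n) tuple marks indices start, start + n, ...; here 40, 120, ...
offset, = ax.plot(xs, np.cosh(xs), marker='^', markevery=(40, 80), ls='--')

n_markers = len(range(0, len(xs), 80))
print(n_markers)  # 13
```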

# Complicated Example: Bubble Plots Use Color and Size
## We won't cover this in class, but you should check it out on your own time.
Below I implement Hill's 7.4.3, with some stylistic choices of my own.

The goal is to plot the average body mass index (BMI) for men in various countries against those countries' gross domestic product (GDP) per capita. Additionally, we want to represent the relative populations of the countries with the size of the markers, or "bubbles" for each country. Finally, we want to color-code each country's marker according to which continent it came from.

In the next three cells we:
1. Read in the necessary data from several tab-separated-value (tsv) files to create `numpy` record arrays (and also set up a dictionary from continent name to the default colors of `matplotlib`)
2. Combine all of that data into one single record array that contains all necessary data (country name, GDP per capita, male BMI, population, and continent)
3. Plot that data as a scatter plot, but varying the sizes of the markers according to population, and varying the color of each marker by the appropriate continent.

In [None]:
# load bmi data
dt = np.dtype([('country', 'U60'), ('bmi', 'float64')])
bmi_data = np.genfromtxt('bmi_men.tsv', dtype=dt, delimiter='\t')

# load gdp data
dt2 = np.dtype([('country', 'U60'), ('gdp', 'float64')])
gdp_data = np.genfromtxt('gdp.tsv', dtype=dt2, delimiter='\t', filling_values=-1, skip_footer=1)

# load population data (try on your own first!)
dt3 = np.dtype([('country', 'U60'), ('population', 'float64')])
pop_data = np.genfromtxt('population_total.tsv', dtype=dt3, delimiter='\t', filling_values=-1, skip_footer=1)

# load continent data
# witchcraft to get a list of the default colors
lines_color_cycle = [p['color'] for p in plt.rcParams['axes.prop_cycle']]
dt4 = np.dtype([('country', 'U60'), ('continent', 'U60')])
con_data = np.genfromtxt('continents.tsv', dtype=dt4, delimiter='\t')
continents = sorted(set(con_data['continent']))  # sorted so colors are assigned the same way every run
continent_colors = dict(zip(continents, lines_color_cycle[:len(continents)]))
continent_colors
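The "witchcraft" above walks over the property cycle one entry at a time. An equivalent (and arguably more readable) route is `Cycler.by_key()`, sketched here. With the default style the list has ten colors, which are also available by the shorthand names `'C0'` through `'C9'`.

```python
import matplotlib.pyplot as plt

plt.style.use('default')
# by_key() turns the property cycle into a dict of lists, one list per property
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
print(len(default_colors), default_colors[0])  # 10 #1f77b4
```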


In [None]:
# this cell just creates one record array with all countries that have valid data in all categories
# (bmi, gdp, population, and continent), so that they can be easily accessed and consistently used
# in the plotting (next cell)

# first, go through each country and assemble its bmi, gdp, population, and color,
# creating a list of tuples
records = []
for country, bmi in bmi_data:
    if country in gdp_data['country'] and country in pop_data['country'] and country in con_data['country']:
        # np.where returns a tuple of index arrays, so [0][0] pulls out the
        # first matching index as a plain integer. (Indexing with the whole
        # index array would instead return a length-1 array, which cannot
        # be used as a dictionary key.)
        gdp_idx = np.where(gdp_data['country'] == country)[0][0]
        pop_idx = np.where(pop_data['country'] == country)[0][0]
        con_idx = np.where(con_data['country'] == country)[0][0]
        gdp = gdp_data['gdp'][gdp_idx]
        pop = pop_data['population'][pop_idx]
        color = continent_colors[con_data['continent'][con_idx]]
        # missing values were filled with -1 when loading, so skip those countries
        if gdp > 0 and pop > 0:
            records.append((country, bmi, gdp, pop, color))

# now convert the list of tuples into a record array
dt5 = np.dtype([('country', 'U60'), ('bmi', 'float64'), ('gdp', 'float64'), ('population', 'float64'), ('color', 'U60')])
all_data = np.array(records, dtype=dt5)

In [None]:
fig, ax = plt.subplots(1, 1)
# all the magic in one line! scatter plot of bmi vs gdp, but get sizes of markers from the population,
# scaled so it looks nice to the eye, and setting the color according to the continent
# note also the use of alpha, which adds some transparency to the bubbles so we can "see through" them
ax.scatter(all_data['gdp'], all_data['bmi'], s=all_data['population'] / max(all_data['population']) * 1000, c=all_data['color'], alpha=0.4)
ax.set_ylabel('Average Male BMI [kg/m$^2$]')
ax.set_xlabel(r'GDP per capita [\$]')

# bogus plots to get a useful legend
for cont, color in continent_colors.items():
    ax.scatter([], [], c=color, alpha=0.4, label=cont)
ax.legend(loc='best')

# using log scale makes the diversity among the low gdp nations more clear
ax.set_xscale('log')

# chop off some high-bmi outliers (maybe not a good idea, but it looks nicer!)
ax.set_ylim(top=30)
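One detail in the `scatter` call above: `s` is the marker *area* in points squared, so dividing by the maximum population and multiplying by 1000 makes bubble area (not radius) proportional to population, with the largest bubble at 1000 pt². Below, that scaling is pulled into a small helper, `bubble_sizes`, which is hypothetical (not part of `matplotlib`); the three data points are made-up values for illustration, not the BMI/GDP data.

```python
import matplotlib.pyplot as plt
import numpy as np

def bubble_sizes(values, max_area=1000.0):
    """Hypothetical helper: marker areas (pt^2) proportional to values, largest = max_area."""
    values = np.asarray(values, dtype=float)
    return values / values.max() * max_area

# made-up illustration values (NOT real country data)
gdp = np.array([1e3, 1e4, 5e4])
bmi = np.array([22.0, 25.0, 27.0])
population = np.array([2e7, 1e8, 4e7])

sizes = bubble_sizes(population)
fig, ax = plt.subplots()
ax.scatter(gdp, bmi, s=sizes, c=['C0', 'C1', 'C2'], alpha=0.4)
ax.set_xscale('log')
```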

# Customizing Tickmarks
The **ticks** are the little lines along the axis edges that set the scale and values in a plot. `matplotlib` is usually pretty good about making these look decent, but you can customize them almost arbitrarily using these methods of axes objects:
- `set_yticks` and `set_xticks` (sets the positions of major or minor tickmarks)
- `set_yticklabels` and `set_xticklabels` (sets the labels for major or minor tickmarks)
- `tick_params` (change appearance of major and/or minor tick marks)
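A sketch of all three methods on a plain $\sin(x)$ curve (not the exercise below): major ticks at multiples of $\pi/2$ with typeset labels, unlabeled minor ticks halfway between them (requested with `minor=True`), and inward-pointing ticks via `tick_params`.

```python
import matplotlib.pyplot as plt
import numpy as np

xs = np.linspace(0, 2 * np.pi, 200)
fig, ax = plt.subplots()
ax.plot(xs, np.sin(xs))

# major ticks at 0, pi/2, pi, 3pi/2, 2pi, with LaTeX labels
major = np.arange(0, 2 * np.pi + 0.01, np.pi / 2)
ax.set_xticks(major)
ax.set_xticklabels(['$0$', r'$\pi/2$', r'$\pi$', r'$3\pi/2$', r'$2\pi$'])

# unlabeled minor ticks halfway between the major ones
minor = np.arange(np.pi / 4, 2 * np.pi, np.pi / 2)
ax.set_xticks(minor, minor=True)

# longer major ticks; all ticks drawn inward
ax.tick_params(which='major', length=7)
ax.tick_params(which='both', direction='in')
```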

# Example: Tinkering with Tickmarks
Here we want to create a figure showing a generic sine curve, without specifying a wavelength or amplitude. That is, we want to show a generic plot of

$$f(x) = A\sin\left(\frac{2\pi x}{\lambda}\right)$$

In [None]:
xs = np.linspace(0, 1, 200)
# implicitly: lambda is 1 and A is one
ys = np.sin(2*np.pi * xs)
fig, ax = plt.subplots(1, 1)
ax.plot(xs, ys)
ax.set_xlabel('$x$')
ax.set_ylabel('$f(x)$')

# TODO: change positions and labels of tick marks to be more "generic"

# TODO: make tick marks point inward (my personal preference)

# show arbitrary label (not a legend); coming up soon!
ax.text(0.95, 0.95, r'$f(x) = A\sin\left(\frac{2\pi x}{\lambda}\right)$', ha='right', va='top', transform=ax.transAxes)

# Errorbars
Most experimental data is meaningless without associated uncertainties. On plots, we represent uncertainties with error bars, expressing the region of confidence around a data point.

`matplotlib` axes objects have the powerful `errorbar` method, which draws a line plot (optionally with markers) and can also show uncertainties in the $x$ and/or $y$ directions in any of the following ways:
- No uncertainty (no errorbar)
- Constant uncertainty for all values
- Varying, but symmetric (plus or minus) uncertainties
- Asymmetric and varying (different positive and negative) uncertainties
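A sketch of the four cases, all on one axes and offset vertically so they don't overlap. The key is the `yerr` (or `xerr`) argument: omitted, a scalar, one value per point, or a pair of arrays whose first row is the lower (minus) error and second row the upper (plus) error. The data are arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np

xs = np.arange(1, 6)
ys = xs**2
fig, ax = plt.subplots()

# 1. no uncertainty: errorbar behaves like plot
no_err = ax.errorbar(xs, ys, fmt='o', label='none')
# 2. one constant uncertainty for every point
ax.errorbar(xs, ys + 10, yerr=2, fmt='s', capsize=3, label='constant')
# 3. symmetric but varying: one value per point
ax.errorbar(xs, ys + 20, yerr=0.1 * ys, fmt='^', capsize=3, label='varying')
# 4. asymmetric: shape (2, N), first row lower (minus) errors, second row upper (plus)
lower = 0.05 * ys
upper = 0.2 * ys
asym = ax.errorbar(xs, ys + 30, yerr=[lower, upper], fmt='v', capsize=3, label='asymmetric')
ax.legend(loc='best')
```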

# Example: "Noisy Drop the Rock"
We want to create simulated data for measuring the position of a falling object as a function of time. We know that an object dropped from rest should roughly follow the equation

$$y(t) = y_0 - \frac{1}{2}gt^2$$

Suppose we could only measure the time to within one tenth of a second, and the position to within 25 centimeters (we're really bad at that, I guess). Let's first generate some bogus experimental data by adding some random noise, and then we'll plot it using `errorbar`.

In [None]:
num_points = 15  # number of measurements
t_unc = 0.1  # uncertainty in time measurements, in seconds
y_unc = 0.25  # uncertainty in position measurements, in meters
y_0 = 10.0  # initial height, in meters
g = 9.81  # acceleration due to gravity, in m/s^2

# np.random.random creates random values between 0 and 1, so we subtract a half, and
# multiply by 2 to get a random number between -1 and 1, and then multiply by the
# uncertainty
expected_times = np.linspace(0,1.5, num_points)
noisy_times = expected_times + 2 * (np.random.random(num_points) - 0.5) * t_unc

expected_ys = y_0 - 0.5 * g * expected_times**2
noisy_ys = expected_ys + 2 * (np.random.random(num_points) - 0.5) * y_unc
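The cell above uses the legacy `np.random.random`. A possible alternative, sketched here, is a seeded `numpy.random.Generator`: its `uniform(low, high, size)` method draws from $[-\delta, \delta)$ directly, and the seed makes the "noisy" data identical every time the notebook runs. The seed `240` is an arbitrary choice.

```python
import numpy as np

num_points = 15
t_unc, y_unc = 0.1, 0.25  # uncertainties in seconds and meters
y_0, g = 10.0, 9.81       # initial height (m) and gravity (m/s^2)

# a seeded Generator makes the "random" noise reproducible between runs
rng = np.random.default_rng(seed=240)

expected_times = np.linspace(0, 1.5, num_points)
noisy_times = expected_times + rng.uniform(-t_unc, t_unc, num_points)

expected_ys = y_0 - 0.5 * g * expected_times**2
noisy_ys = expected_ys + rng.uniform(-y_unc, y_unc, num_points)
```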

In [None]:
plt.style.use('default')
# plt.style.use('custom.mplstyle')
fig, ax = plt.subplots(1, 1)
# errorbar's positional order is xs, ys, uncertainty in ys (NOT XS!), then uncertainty in xs,
# so naming yerr and xerr explicitly avoids mix-ups.
# to get rid of the line connecting points, set fmt to just be circular markers
# to get caps (lines at the ends of the errorbars), set capsize to the length of the caps in pts
ax.errorbar(noisy_times, noisy_ys, yerr=y_unc, xerr=t_unc, capsize=3, fmt='o', ecolor='black', label='Noisy Data')
ax.plot(expected_times, expected_ys, label='Perfect Data')
ax.set_ylabel('Height [m]')
ax.set_xlabel('Time [s]')
ax.legend(loc='best')

# Practical Tips
- **Use Examples to Learn**:
    - Google "matplotlib errorbar asymmetric" or something similar; there is usually an excellent Stack Overflow question
    - Use [matplotlib gallery](https://matplotlib.org/stable/gallery/index.html) for inspiration
- **Create a styleguide**
    - Create your own `matplotlibrc` file that you include in the same directory as your projects
    - Get a basic one [here](https://matplotlib.org/stable/tutorials/introductory/customizing.html#a-sample-matplotlibrc-file). Copy it into a file named `matplotlibrc` in the same directory, and `matplotlib` will use the styles you define there as the default.
- **Don't Mess with Font Sizes**: Change the figure size instead
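If you want a style change for just one figure rather than a whole project, `plt.rc_context` applies rcParams temporarily. A minimal sketch; the particular settings are only examples, and a style file can similarly be passed as a path to `plt.style.use`.

```python
import matplotlib.pyplot as plt

# these settings apply only inside the with-block
with plt.rc_context({'xtick.direction': 'in', 'ytick.direction': 'in', 'axes.grid': True}):
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot([0, 1, 2], [0, 1, 4])
    inside = plt.rcParams['xtick.direction']

# after the block, the previous values are restored
outside = plt.rcParams['xtick.direction']
```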

# Labeling Plots: Philosophical Points
- A few well-chosen words and symbols can ensure a figure is always interpreted in context
- Don't add piles of text, though... just enough to provide context, like what you might include in a legend label
- Particularly good application: qualitatively labeling panels in a multi-panel plot

# Labeling Plots: Text Only with `ax.text`
### Key parameters:
- `x` and `y`: coordinates where text should be placed
- `s`: string that is the text to be added to the plot

### Other useful parameters:
- `bbox`: dictionary of parameters describing a box around the text (look this up)
- `ha`: horizontal alignment of text (`'center'`, `'left'`, or `'right'`)
- `va`: vertical alignment of text (`'center'`, `'top'`, `'bottom'`, `'baseline'`, or `'center_baseline'`)
- `transform`: determines how coordinates should be interpreted (are they relative to data, the width of the figure, the width of the axes? This is **important**!)
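For instance, the panel-labeling use case above might look like this sketch: text placed in axes coordinates near the upper-left corner, anchored by its own upper-left corner, with a rounded `bbox` behind it. The damped cosine is just placeholder data.

```python
import matplotlib.pyplot as plt
import numpy as np

xs = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
ax.plot(xs, np.exp(-xs / 3) * np.cos(xs))

# (0.03, 0.97) in axes coordinates is near the upper-left corner,
# no matter what the data limits are
label = ax.text(0.03, 0.97, '(a)', transform=ax.transAxes, ha='left', va='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
```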


# Challenge: Labeling a Critical Point (Solution at end)
Let's locate the critical points of the polynomial $f(x) = 2x^3 + x^2 - 8x -5$, which has a derivative of $f^\prime(x) = 6x^2 + 2x - 8 = (2x - 2)(3x + 4)$, and thus has critical points at $x=1$ and $x = -4/3$.

The code below does exactly this, placing markers at the critical points and labeling them. Execute it, but then notice... the labels are a little crowded, aren't they? Change it so the first critical point label is centered above the marker, and the second is centered below the marker. You might also need to manually adjust the plotting limits to accommodate the new labels.

In [None]:
f = lambda x: 2 * x**3 + x**2 - 8 * x - 5
xs = np.linspace(-2, 2)
ys = f(xs)
x_crits = [-4.0/3.0, 1]
y_crits = [f(x) for x in x_crits]
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(xs, ys)
ax.scatter(x_crits, y_crits, marker='o', color='Black')
# TODO: fix alignment to make this pretty!
ax.text(x_crits[0], y_crits[0], '$x_0 = -4/3$')
ax.text(x_crits[1], y_crits[1], '$x_0 = 1$')

# In axes coordinates, the lower left corner is (0,0) and the upper right is (1,1)
Axes coordinates are...
- decoupled from the actual data or limits on the subplot
- useful when you want a qualitative label for the plot that's not associated with a particular point
- used in `ax.text` when you set the optional keyword argument `transform=ax.transAxes`
- used for annotation in a frustratingly different way (more on that soon)
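To see that axes coordinates really are decoupled from the data, this sketch converts the axes-coordinate corners into data coordinates using the `transAxes` and `transData` transforms (axes → display pixels → data). Whatever limits are set, $(1, 1)$ lands on the upper-right data limits and $(0, 0)$ on the lower-left.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(-5, 5)

# axes coordinates -> display (pixel) coordinates -> data coordinates
to_data = ax.transData.inverted()
upper_right = to_data.transform(ax.transAxes.transform((1, 1)))
lower_left = to_data.transform(ax.transAxes.transform((0, 0)))
print(upper_right, lower_left)  # approximately (10, 5) and (0, -5)
```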

# Extended Example: Adding a label for the function (Complete at end)
Goal: Add a label to the upper right corner that shows the function, properly typeset.

In [None]:
# xs and ys should still be set appropriately according to f(x) = 2x^3 + x^2 - 8x - 5
f = lambda x: 2 * x**3 + x**2 - 8 * x - 5
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(xs, ys)
ax.scatter(x_crits, y_crits, marker='o', color='Black')
ax.text(x_crits[0], y_crits[0], '$x_0 = -4/3$',ha='center', va='bottom')
ax.text(x_crits[1], y_crits[1], '$x_0 = 1$',ha='center', va='top')
ylo, yhi = ax.get_ylim() # min(ys), max(ys)
height = yhi - ylo
margin = 0.05 * height
ax.set_ylim(ylo - margin, yhi + margin)

# TODO: Add label in the upper right hand corner giving the functional form


# Annotations are labels with arrows
- Use `ax.annotate`, which is more complex than `ax.text`
- Frustratingly different syntax than `text`
- Need **two locations**, the position of the text, and the location the arrow should point to.

# Calling Sequence:
- text comes **first**, called `text` (named `s` in older versions of `matplotlib`)
- then, coordinates of point to be annotated as a **tuple** called `xy`
- then, coordinates for the location of the text, in `xytext`
- rather than specifying a `transform`, we specify `xycoords` to be one of `'data'`, `'figure fraction'`, or `'axes fraction'` (among other options). This sets the interpretation of `xy`
- `textcoords` does the same, but for `xytext`
- Finally, you can customize the appearance of the arrows with a dictionary called `arrowprops`
```python
ax.annotate('my label', (1, 0), (2, 2), 'data', 'data', arrowprops={'arrowstyle': '->'})
```
This creates a text annotation at (2, 2) with the text "my label" and an arrow that points from the text to the point (1, 0).
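Passing all of those arguments positionally is easy to get wrong. A sketch of the same annotation (text at (2, 2), arrow pointing to (1, 0)) with every argument named:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(3.4, 3.4))
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

# xy is the point being annotated; xytext is where the text goes
ann = ax.annotate('my label', xy=(1, 0), xytext=(2, 2),
                  xycoords='data', textcoords='data',
                  arrowprops={'arrowstyle': '->'})
```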

# Very Basic Annotation in Action

In [None]:
fig, ax = plt.subplots(figsize=(3.4, 3.4))
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
ax.annotate('my label', (1, 1), (2, 2), 'data', 'data', arrowprops={'arrowstyle': '->'})

# Example: Labels to Annotations (Complete at end)
Goal: convert the labels near critical points to annotations with labels.

In [None]:
# xs and ys should still be set appropriately according to f(x) = 2x^3 + x^2 - 8x - 5
f = lambda x: 2 * x**3 + x**2 - 8 * x - 5
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(xs, ys)
ax.scatter(x_crits, y_crits, marker='o', color='Black')
# TODO: Adjust these to be annotations
ax.text(x_crits[0], y_crits[0], '$x_0 = -4/3$',ha='center', va='bottom')
ax.text(x_crits[1], y_crits[1], '$x_0 = 1$',ha='center', va='top')
ylo, yhi = min(ys), max(ys)
height = yhi - ylo
ax.set_ylim(ylo - 0.10 * height, yhi + 0.10 * height)

# Add label in the upper right hand corner giving the functional form (done before)
ax.text(0.98, 0.98, '$f(x) = 2x^3 + x^2 - 8x - 5$', ha='right', va='top', transform=ax.transAxes)


# Reference: The Completed Plot
Here we chose to use arrows to show how you can customize them. We also used `'axes fraction'` as the text coordinates option, to easily place them on the axis canvas wherever we wanted, regardless of the data. You might (and probably did) make some other choices, and that's okay.

In [None]:
# generate plot data
f = lambda x: 2 * x**3 + x**2 - 8 * x - 5
xs = np.linspace(-2, 2)
ys = f(xs)

# hard-code in critical point locations
x_crits = [-4.0/3.0, 1]
y_crits = [f(x) for x in x_crits]

# line plot of main data; scatter plot locations of critical points
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(xs, ys)
ax.scatter(x_crits, y_crits, marker='o', color='Black')

# add in annotations
ax.annotate('$x_0 = -4/3$', (x_crits[0], y_crits[0]), (0.1, 0.7), 'data', 'axes fraction', arrowprops={'arrowstyle': '->', 'shrinkB': 4, 'shrinkA': 0})
ax.annotate('$x_0 = 1$', (x_crits[1], y_crits[1]), (0.65, 0.3), 'data', 'axes fraction',  arrowprops={'arrowstyle': '->', 'shrinkB': 4, 'shrinkA': 0})

# add in function label
ax.text(0.98, 0.98, '$f(x) = 2x^3 + x^2 - 8x - 5$', ha='right', va='top', transform=ax.transAxes)

# Example: Hyperbolic trig functions revisited: Object-Oriented Approach SOLUTION

In [None]:
# first set up the figure and axes
fig, ax = plt.subplots(1, 1)

# plotting calls and other adjustments are methods of the
# axes object rather than a function from pyplot
ax.plot(xs, ys1, label=r'$\sinh(x)$')
ax.plot(xs, ys2, ls='--', label=r'$\cosh(x)$')
ax.set_xlabel('$x$')
ax.set_ylabel('$f(x)$')
ax.legend(loc='best')

# Challenge: 3 Harmonics in 3 Panels Solution
Goal: plot the three functions $f_1(x) = \sin(x)$, $f_2(x) = \sin(2x)$, and $f_3(x) = \sin(3x)$ in three vertically-stacked panels, all from $x=0$ to $x=2\pi$.

**Pro Tip**: Specifying `sharex=True` in the call to `subplots` will force stacked plots to share the same $x$-limits and ticks (and, by default, to show $x$ tick labels only on the bottom panel), regardless of the range over which each is plotted

In [None]:
# first, create the data
xs = np.linspace(0, 2 * np.pi, 1000)
ys1 = np.sin(xs)
ys2 = np.sin(2 * xs)
ys3 = np.sin(3 * xs)

# Make figure and axes: 3 rows, one column!
fig, axes = plt.subplots(3, 1, sharex=True)

# plot each function in its own axis
# NOTE: I've added some trickery here to get all three lines to
# appear on a single legend. Such things are necessary since
# a legend "belongs" to a single axis
axes[0].plot(xs, ys1, label=r'$\sin(x)$')

# bogus empty plots to create a valid legend. Note how
# the linestyles and colors match that of the second and third
# panels
axes[0].plot([], [], ls='--', color='C1', label=r'$\sin(2x)$')
axes[0].plot([], [], ls=':', color='C2', label=r'$\sin(3x)$')

# actual lower panels for the higher harmonics
axes[1].plot(xs, ys2, ls='--', color='C1')
axes[2].plot(xs, ys3, ls=':', color='C2')

# add the legend to the top panel only
axes[0].legend(loc='best')

# label the axes. Skip x-axes for the top two panels since
# it is redundant with the bottom panel.
for ax in axes:
    ax.set_ylabel('$f(x)$')
axes[-1].set_xlabel(r'$x$')

# Example: Tinkering with Tickmarks (Complete)
Here we want to create a figure showing a generic sine curve, without specifying a wavelength or amplitude. That is, we want to show a generic plot of

$$f(x) = A\sin\left(\frac{2\pi x}{\lambda}\right)$$

In [None]:
xs = np.linspace(0, 1, 200)
# implicitly: lambda is 1 and A is one
ys = np.sin(2*np.pi * xs)
fig, ax = plt.subplots(1, 1)
ax.plot(xs, ys)
ax.set_xlabel('$x$')
ax.set_ylabel('$f(x)$')

# TODO: change positions and labels of tick marks to be more "generic"
ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
ax.set_yticks([-1, -0.5, 0, 0.5, 1])
ax.set_xticklabels(['${}$'.format(label) for label in (0, r'\lambda/4', r'\lambda/2', r'3\lambda/4', r'\lambda')])
ax.set_yticklabels(['${}$'.format(label) for label in ('-A', '-A/2', '0', 'A/2', 'A')])

# TODO: make tick marks point inward (my personal preference)
ax.tick_params(axis='both', direction='in')

# show arbitrary label (not a legend); coming up soon!
ax.text(0.95, 0.95, r'$f(x) = A\sin\left(\frac{2\pi x}{\lambda}\right)$', ha='right', va='top', transform=ax.transAxes)

# Challenge: Labeling a Critical Point (Solution)
Let's locate the critical points of the polynomial $f(x) = 2x^3 + x^2 - 8x -5$, which has a derivative of $f^\prime(x) = 6x^2 + 2x - 8 = (2x - 2)(3x + 4)$, and thus has critical points at $x=1$ and $x = -4/3$.

The code below does exactly this, placing markers at the critical points and labeling them. Execute it, but then notice... the labels are a little crowded, aren't they? Change it so the first critical point label is centered above the marker, and the second is centered below the marker. You might also need to manually adjust the plotting limits to accommodate the new labels.

In [None]:
f = lambda x: 2 * x**3 + x**2 - 8 * x - 5
xs = np.linspace(-2, 2)
ys = f(xs)
x_crits = [-4.0/3.0, 1]
y_crits = [f(x) for x in x_crits]
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(xs, ys)
ax.scatter(x_crits, y_crits, marker='o', color='Black')

# the new alignments make things better, but they're still not perfect.
# In particular the second one is still quite close to the point itself.
# the `annotate` method of axes might be a better choice, but it is more
# complicated to implement.
ax.text(x_crits[0], y_crits[0], '$x_0 = -4/3$', va='bottom', ha='center')
ax.text(x_crits[1], y_crits[1], '$x_0 = 1$', va='top', ha='center')

# ensure a 5% margin at the top and bottom... scales gracefully!
ylo, yhi = ax.get_ylim()
height = yhi - ylo
margin = 0.05 * height
ax.set_ylim(ylo - margin, yhi + margin)

# Extended Example: Adding a label for the function (Solution)
Goal: Add a label to the upper right corner that shows the function, properly typeset.

In [None]:
# xs and ys should still be set appropriately according to f(x) = 2x^3 + x^2 - 8x - 5, but just in case, we re-implement them here
f = lambda x: 2 * x**3 + x**2 - 8 * x - 5
xs = np.linspace(-2, 2)
ys = f(xs)
x_crits = [-4.0/3.0, 1]
y_crits = [f(x) for x in x_crits]


fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(xs, ys)
ax.scatter(x_crits, y_crits, marker='o', color='Black')
ax.text(x_crits[0], y_crits[0], '$x_0 = -4/3$',ha='center', va='bottom')
ax.text(x_crits[1], y_crits[1], '$x_0 = 1$',ha='center', va='top')
ylo, yhi = ax.get_ylim() # min(ys), max(ys)
height = yhi - ylo
margin = 0.05 * height
ax.set_ylim(ylo - margin, yhi + margin)

# Add label in the upper right hand corner giving the functional form.
# With transform=ax.transAxes, the x and y locations are fractions of the axes
# width/height, so this will be 98% of the way to the right and 98% of the way
# to the top. We also set the alignment so that the upper right corner of the
# text is placed at that location. The transform is what forces `text` to use
# axes coordinates rather than data coordinates.
ax.text(0.98, 0.98, r'$f(x) = 2x^3+x^2-8x-5$', ha='right', va='top', transform=ax.transAxes)

# Example: Labels to Annotations (Solution)
Goal: convert the labels near critical points to annotations with labels.

In [None]:
# xs and ys should still be set appropriately according to f(x) = 2x^3 + x^2 - 8x - 5
f = lambda x: 2 * x**3 + x**2 - 8 * x - 5
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(xs, ys)
ax.scatter(x_crits, y_crits, marker='o', color='Black')
# TODO: Adjust these to be annotations
# Note that the locations of the annotation point remain the same, but now we specify some
# text coordinates: (0, 5) and then (0, -5). By then specifying that these are meant to be
# interpreted as "offsets in points", this means they will be 5 points above and 5 points
# below (respectively) the actual annotation points. This is a powerful way to get arbitrary
# offsets without having to fiddle with the data coordinates.
#
# And since we don't specify any arrow properties, there is no arrow connecting the text to
# the annotation location.
ax.annotate('$x_0 = -4/3$', (x_crits[0], y_crits[0]), (0, 5), xycoords='data', textcoords='offset points', ha='center')
ax.annotate('$x_0 = 1$', (x_crits[1], y_crits[1]), (0, -5), xycoords='data', textcoords='offset points', ha='center', va='top')
ylo, yhi = min(ys), max(ys)
height = yhi - ylo
ax.set_ylim(ylo - 0.10 * height, yhi + 0.10 * height)

# Add label in the upper right hand corner giving the functional form
ax.text(0.98, 0.98, '$f(x) = 2x^3 + x^2 - 8x - 5$', ha='right', va='top', transform=ax.transAxes)
