---
title: Making many added variable plots with purrr and ggplot2
author: Ariel Muldoon
date: '2018-01-31'
slug: added-variable-plots
categories:
  - r
  - statistics
tags:
  - ggplot2
  - purrr
draft: FALSE
description: "In this post I show one approach for making added variable plots from a model with many continuous explanatory variables. Since this is done for every variable in the model, I show how to automate the process via functions from package purrr."
---
```{r setup, echo = FALSE, message = FALSE, purl = FALSE}
knitr::opts_chunk$set(comment = "#")
```
Last week two of my consulting meetings ended up on the same topic: making added variable plots.
In both cases, the student had a linear model of some flavor that had several continuous explanatory variables. They wanted to plot the estimated relationship between each variable in the model and the response. This could easily lead to a lot of copying and pasting of code, since they want to do the same thing for every explanatory variable in the model. I worked up some example code showing an approach on how one might automate the task in R with functions and loops, and thought I'd generalize it for a blog post.
# The basics of added variable plots
Added variable plots (aka partial regression plots, adjusted variable plots, individual coefficient plots) are "results" plots. They show the estimated relationship between the response and an explanatory variable *after accounting for the other variables in the model*. If working with only two continuous explanatory variables, a 3-dimensional plot could be used in place of an added variable plot (if one likes those sorts of plots `r emo::ji("smiley")`). Once there are many variables in the model, though, we don't have enough plotting dimensions to show how all the variables relate to the response simultaneously, and so we often use added variable plots to show the fitted relationships.
There are packages available for making added variable plots in R, such as **effects** and **visreg**. I like a bit more flexibility, which I get by making my own plots. To do this I need to extract the appropriate predictions and confidence intervals from the model.
When making an added variable plot, it is fairly standard to make the predictions with all other variables fixed to their medians or means. I use medians today. Note that in my example I'm demonstrating code for the relatively simple case where there are no interactions between continuous variables in the model. Continuous-by-continuous interactions would involve a more complicated set-up for making plots.
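To see why the choice between medians and means matters less than it might seem, here is a small base R sketch. It is an illustration only: the object names `fit_demo`, `at_median`, and `at_mean` are made up for the example, and it assumes a model with no interactions, like the one fit later in this post. In that case, changing where the other variables are held fixed only shifts the fitted line up or down on the model scale; it does not change its shape.

```r
# Illustrative model (hypothetical name); same form as the model used in this post
fit_demo = lm(log(mpg) ~ disp + hp + drat + wt, data = mtcars)

others = c("hp", "drat", "wt")

# Prediction data for focus variable "disp", others fixed at medians or at means
at_median = data.frame(disp = mtcars$disp, lapply(mtcars[others], median))
at_mean = data.frame(disp = mtcars$disp, lapply(mtcars[others], mean))

# Difference between the two sets of predictions on the log (model) scale
shift = predict(fit_demo, newdata = at_median) -
     predict(fit_demo, newdata = at_mean)

# The smallest and largest differences agree (up to rounding): a constant shift
range(shift)
```

After back-transforming from the log scale, that constant shift becomes a constant multiplier, so the two curves differ by a fixed percentage rather than a fixed amount.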
# R packages
The main workhorses I'm using today are **purrr** for looping through variables/lists and **ggplot2** for plotting. I also use helper functions from **dplyr** for data manipulation and **broom** for getting the model predictions and standard errors.
```{r, warning = FALSE, message = FALSE}
library(dplyr) # v. 0.8.0
library(ggplot2) # v. 3.2.0
library(purrr) # v. 0.3.1
library(broom) # v. 0.5.1
```
# The linear model
My example model is a linear model with a transformed response variable, fit using `lm()`. The process works the same for generalized linear models fit with `glm()` and would be very similar for other linear models (although you may have to calculate any standard errors manually).
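For readers working with `glm()`, the same idea can be sketched in base R as follows. This is only an illustration: the Poisson model for `carb` and the object names are invented for the example and are not part of this post's analysis. Predictions and approximate limits are computed on the link scale and then back-transformed with the family's inverse link.

```r
# Hypothetical example model: number of carburetors as a Poisson count
fit_glm = glm(carb ~ hp + wt, family = poisson, data = mtcars)

# Focus variable "hp", with "wt" fixed at its median
newdat = data.frame(hp = mtcars$hp, wt = median(mtcars$wt))

# Predictions and standard errors on the link (log) scale
p = predict(fit_glm, newdata = newdat, type = "link", se.fit = TRUE)

# Back-transform with the inverse link stored in the family object
inv = family(fit_glm)$linkinv
newdat$pred = inv(p$fit)
newdat$lower = inv(p$fit - 2 * p$se.fit)
newdat$upper = inv(p$fit + 2 * p$se.fit)

head(newdat)
```

Using `family(fit_glm)$linkinv` rather than hard-coding `exp()` means the same lines would work if the link function changed.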
My linear model is based on five continuous variables from the *mtcars* dataset.
The model I fit uses a log transformation for the response variable, so predictions and confidence interval limits will need to be back-transformed prior to plotting to show the relationship between each variable and the response on the original scale. If you don't have a transformation you can skip this step.
```{r}
fit1 = lm( log(mpg) ~ disp + hp + drat + wt, data = mtcars)
```
# The explanatory variables
The approach I'm going to take is to loop through the explanatory variables in the model, create datasets for prediction, get the predictions, and make the plots. I'll end with one added variable plot per variable.
I could write out a vector of variable names to loop through manually, but I prefer to pull them out of the model. In the approach I use, the first variable in the output is the response variable. I don't need the response variable here, so I remove it with `-1`.
```{r}
( mod_vars = all.vars( formula(fit1) )[-1] )
```
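As a quick check of what `all.vars()` returns, the sketch below applies it to a bare formula. Note that it returns variable names rather than model terms, so `log(mpg)` shows up as `mpg`. The second formula, with a hypothetical `poly()` term that is not in my model, shows how this differs from the term labels.

```r
f = log(mpg) ~ disp + hp + drat + wt

# Variable names, response first
all.vars(f)
# "mpg"  "disp" "hp"   "drat" "wt"

# Dropping the first element leaves the explanatory variables
all.vars(f)[-1]

# Hypothetical formula with a polynomial term, for comparison
f2 = log(mpg) ~ poly(disp, 2) + wt
all.vars(f2)[-1]                  # "disp" "wt"
attr(terms(f2), "term.labels")    # "poly(disp, 2)" "wt"
```

Getting the bare variable names is what makes the later steps work, since those names are used to pick columns out of the dataset.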
# A function for making a prediction dataset
The first step in making an added variable plot manually is to create the dataset to use for calculating model predictions. This dataset will contain the observed data for the explanatory variable of interest (aka the *focus* variable) with all other variables fixed to their medians.
Below is a function for doing this task. The function takes a dataset, a vector of all the variables in the model (as strings), and the name of the focus variable (as a string). The **dplyr** `*_at()` functions can take strings as input, so I use `summarise_at()` to calculate the medians of the non-focus variables. I bind the summary values to the focus variable data from the original dataset with `cbind()`, since `cbind()` allows recycling.
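```{r}
preddat_fun = function(data, allvars, var) {
     # Medians of every model variable except the focus variable
     sums = summarise_at(data,
                         vars( one_of(allvars), -one_of(var) ),
                         median)
     # The one-row summary is recycled to match the rows of the focus variable
     cbind( select_at(data, var), sums)
}
```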
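In more recent versions of **dplyr** (1.0.0 and later) the `*_at()` functions are superseded by `across()`. Here is a sketch of how the same function might look with that interface. The names `preddat_across` and `vars_demo` are made up for this example, and `cbind()` is kept for the recycling.

```r
library(dplyr) # assumes v. 1.0.0 or later for across()

# Hypothetical alternative to preddat_fun() using across() and all_of()
preddat_across = function(data, allvars, var) {
     others = setdiff(allvars, var)
     # One-row data frame of medians for the non-focus variables
     sums = summarise(data, across(all_of(others), median))
     # cbind() recycles the single row of medians to every row of the focus variable
     cbind( select(data, all_of(var)), sums)
}

vars_demo = c("disp", "hp", "drat", "wt")
head( preddat_across(mtcars, vars_demo, "disp") )
```

The focus column keeps its observed values while every other column holds a single repeated median.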
Here's what the result of the function looks like for a single focus variable, "disp" (showing first six rows).
```{r}
head( preddat_fun(mtcars, mod_vars, "disp") )
```
# Making a prediction dataset for each variable
Now that I have a working function, I can loop through each variable in the `mod_vars` vector and create a prediction dataset for each one. I'll use `map()` from **purrr** for the loop. I use `set_names()` prior to `map()` so each element of the resulting list will be labeled with the name of the focus variable of that dataset. This helps me stay organized.
The result is a list of prediction datasets, one for each variable in my model.
```{r}
pred_dats = mod_vars %>%
     set_names() %>%
     map( ~preddat_fun(mtcars, mod_vars, .x) )
```
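If `set_names()` is unfamiliar: called with no second argument, it names each element of a vector after its own value, and `map()` carries those names through to the output list. A rough base R analogue, shown here only to illustrate the naming behavior, uses `setNames()` and `lapply()`. The `ranges` object is just a stand-in for the prediction datasets, and the `_demo` names are made up.

```r
mod_vars_demo = c("disp", "hp", "drat", "wt")

# Name each element after itself, as set_names() does
named_vars = setNames(nm = mod_vars_demo)
named_vars

# lapply() keeps the names, so the resulting list is labeled by variable
ranges = lapply(named_vars, function(v) range(mtcars[[v]]))
names(ranges)
ranges$wt
```

Having the names on the list is what later allows pulling out a single element with code like `preds$disp`.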
# Calculate model predictions for each variable
Once the prediction datasets are created, the predictions can be calculated from the model and added to each dataset. I do this on the model scale, since I want to make confidence intervals with the standard errors prior to back-transforming.
The `augment()` function from **broom** works with a variety of model objects, including *lm* and *glm* objects. It can take new datasets for prediction via the "newdata" argument. I use it here to add both the prediction and the standard errors of the predictions to each dataset.
I do this task by looping through the prediction datasets with `map()`, first to add the predictions via `augment()` and then to calculate approximate confidence intervals. Since my response was on the log scale, I also back-transform the predictions and confidence interval limits to the original data scale. This step isn't needed for a model without a transformation or link function.
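```{r}
# Note: newer broom versions may need se_fit = TRUE in augment() to return .se.fit
preds = pred_dats %>%
     map(~augment(fit1, newdata = .x) ) %>%
     map(~mutate(.x,
                 lower = exp(.fitted - 2*.se.fit),
                 upper = exp(.fitted + 2*.se.fit),
                 pred = exp(.fitted) ) )
```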
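To see what is going on underneath, or as a fallback when `augment()` isn't an option, roughly the same calculation can be sketched with `predict()` in base R. The object names here are made up for the example. The multiplier of 2 is an approximation to the t quantile, which `qt()` gives exactly.

```r
fit_demo = lm(log(mpg) ~ disp + hp + drat + wt, data = mtcars)

# Prediction data for focus variable "disp", others at their medians
newdat = data.frame(disp = mtcars$disp,
                    hp = median(mtcars$hp),
                    drat = median(mtcars$drat),
                    wt = median(mtcars$wt))

# Fitted values and their standard errors on the log scale
p = predict(fit_demo, newdata = newdat, se.fit = TRUE)

# Exact multiplier for a 95% interval, to compare with the approximate value of 2
mult = qt(0.975, df = p$df)
mult

# Back-transform predictions and approximate limits to the original scale
newdat$pred = exp(p$fit)
newdat$lower = exp(p$fit - 2 * p$se.fit)
newdat$upper = exp(p$fit + 2 * p$se.fit)

head(newdat)
```

Because the limits are built on the log scale and then exponentiated, the resulting intervals are not symmetric around the prediction on the original scale.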
Here is what the structure of the list elements looks like now (showing only the first list element).
# Making one added variable plot
With the predictions successfully made it's time for plotting.
Here is the format of the plots I will be making, plotting the fitted line and showing the data on the x axis with a rug plot. I make the output black and white with larger text via `theme_bw(base_size = 14)` and clean up the axis labels via `labs()`.
```{r}
ggplot(data = preds$disp, aes(x = disp, y = pred) ) +
     geom_line(size = 1) +
     geom_ribbon(aes(ymin = lower, ymax = upper), alpha = .25) +
     geom_rug(sides = "b") +
     theme_bw(base_size = 14) +
     labs(x = "Displacement (cu.in.)",
          y = "Miles/(US) gallon") +
     ylim(10, 32)
```
# Build a plotting function
Once I've worked out what one plot should look like, I can make a function for making plots. This is a reasonable approach when making many plots that all look basically the same.
One problem I anticipated running into when automating the plotting is with the x axis labels. The variable names in the dataset aren't very nice looking. If I want the x axis labels to be more polished in the plots I'll need replacement labels. I decided to make a vector of nicer labels, one label for each focus variable. This vector needs to be the same length and in the same order as the vector of variable names and the list of prediction datasets so each plot gets the correct new axis label.
```{r}
xlabs = c("Displacement (cu.in.)",
          "Gross horsepower",
          "Rear axle ratio",
          "Weight (1000 lbs)")
```
My plotting function has three arguments: the dataset to plot, the explanatory variable to plot on the x axis (as a string), and the label for the x axis (also as a string).
I will use the `.data` pronoun for passing strings in `aes()`. This approach became available starting with **rlang** version 0.4.0.
If using earlier versions of **ggplot2** and **rlang**, use `aes_string()` with code `aes_string(x = variable, y = "pred")`. Note your variable names for `x` must be syntactically valid to use the `aes_string()` approach.
```{r}
pred_plot = function(data, variable, xlab) {
     ggplot(data, aes(x = .data[[variable]], y = pred) ) +
          geom_line(size = 1) +
          geom_ribbon(aes(ymin = lower, ymax = upper), alpha = .25) +
          geom_rug(sides = "b") +
          theme_bw(base_size = 14) +
          labs(x = xlab,
               y = "Miles/(US) gallon") +
          ylim(10, 32)
}
```
Here is the plotting function in action. I plot the "disp" variable, which is the first element of the three lists (prediction datasets, variables, axis labels). This looks just like the plot I manually made above.
```{r}
pred_plot(preds[[1]], mod_vars[1], xlabs[1])
```
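A note for readers on newer package versions than the ones listed above: starting with **ggplot2** 3.4.0, the `size` argument for lines was deprecated in favor of `linewidth`. Here is a sketch of the same function with that change. The function name `pred_plot_lw` and the tiny dataset are made up so the example stands alone, and I've left off `ylim()` since the toy values don't need it.

```r
library(ggplot2) # assumes v. 3.4.0 or later for linewidth

# Hypothetical variant of pred_plot() using linewidth instead of size
pred_plot_lw = function(data, variable, xlab) {
     ggplot(data, aes(x = .data[[variable]], y = pred) ) +
          geom_line(linewidth = 1) +
          geom_ribbon(aes(ymin = lower, ymax = upper), alpha = .25) +
          geom_rug(sides = "b") +
          theme_bw(base_size = 14) +
          labs(x = xlab,
               y = "Miles/(US) gallon")
}

# Made-up prediction data, just to exercise the function
toy = data.frame(disp = c(100, 200, 300, 400),
                 pred = c(24, 21, 18, 16),
                 lower = c(22, 19.5, 16.5, 14),
                 upper = c(26, 22.5, 19.5, 18))

p = pred_plot_lw(toy, "disp", "Displacement")
p
```

Nothing else about the function needs to change; the `.data` pronoun works the same way in both versions.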
# Making all the plots
The very last step is to make all the plots. Because I want to loop through three different lists (the prediction datasets, the variables, and the axis labels), this can be done via `pmap()` from **purrr**. The `pmap()` function loops through all three lists simultaneously.
```{r}
all_plots = pmap( list(preds, mod_vars, xlabs), pred_plot)
```
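For comparison, `Map()` from base R does the same kind of parallel looping, passing the first elements of each input to the function together, then the second elements, and so on. The sketch below uses a stand-in function (`describe`, invented for this example) in place of `pred_plot()` so it runs without any packages. The `_demo` objects are also made up.

```r
mod_vars_demo = c("disp", "hp")
xlabs_demo = c("Displacement", "Gross horsepower")
dats_demo = list(disp = mtcars["disp"], hp = mtcars["hp"])

# Stand-in for pred_plot(): takes a dataset, a variable name, and a label
describe = function(data, variable, xlab) {
     sprintf("%s: %d rows, median %s = %.1f",
             xlab, nrow(data), variable, median(data[[variable]]))
}

# Elements are matched by position across the three inputs
out = Map(describe, dats_demo, mod_vars_demo, xlabs_demo)

# Like pmap(), the output takes its names from the first input
names(out)
out$hp
```

Since the matching is purely by position, this also shows why the three inputs need to be in the same order.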
## Using the plots
The plots can be printed all at once as above or individually using indexes or list names, using code such as `all_plots[[1]]` or `all_plots$disp`. Plots can also be saved for use outside of R. I show some examples of how to do this in a different post.
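As one possibility for saving (a sketch, and not necessarily the approach from that other post), a named list of plots can be written to files by looping over the plots and their names together with `iwalk()` from **purrr**. The `plots_demo` list below is a made-up stand-in for `all_plots`, and the files go to a temporary directory.

```r
library(ggplot2)
library(purrr)

# Stand-in for all_plots: a small named list of plots
plots_demo = list(disp = ggplot(mtcars, aes(x = disp, y = mpg) ) + geom_point(),
                  hp = ggplot(mtcars, aes(x = hp, y = mpg) ) + geom_point() )

out_dir = tempdir()

# In iwalk(), .x is the plot and .y is its name in the list
iwalk(plots_demo,
      ~ggsave(filename = file.path(out_dir, paste0(.y, ".pdf") ),
              plot = .x, width = 5, height = 4) )

# Check that one file per plot was written
file.exists( file.path(out_dir, c("disp.pdf", "hp.pdf") ) )
```

Using the list names for the file names is one more place where labeling the list elements earlier pays off.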
It might be nice to combine these individual plots into a single multi-plot. A faceted plot would be an option, but the approach I've done here in its current form isn't a great one for faceting (although I'm sure it could be modified with that in mind).
The individual plots can be combined into a single figure via a function from the **cowplot** package, though, without too much trouble. Note that **cowplot** was opinionated about the theme in the past, which is why I use it without loading it here.
The `plot_grid()` function can take a list of plots, which is what I have. It has a variety of options you might want to explore for getting the plots stitched together nicely.
```{r}
cowplot::plot_grid(plotlist = all_plots,
                   labels = "AUTO",
                   align = "hv")
```
**Addendum:** Package **rms** makes added variable plots via **ggplot2** and **plotly** along with simultaneous confidence bands for any model type the package works with. That includes linear models and generalized linear models excluding the negative binomial family. This may be a useful place to start if you are working with these kinds of models.
# Just the code, please
```{r getlabels, echo = FALSE, purl = FALSE}
labs = knitr::all_labels()
labs = labs[!labs %in% c("setup", "toc", "getlabels", "allcode", "makescript")]
```
```{r makescript, include = FALSE, purl = FALSE}
options(knitr.duplicate.label = "allow") # Needed to purl like this
input = knitr::current_input() # filename of input document
output = here::here("static", "script", paste(tools::file_path_sans_ext(input), "R", sep = ".") )
knitr::purl(input, output, documentation = 0, quiet = TRUE)
```
Here's the code without all the discussion. Copy and paste the code below or you can download an R script of uncommented code [from here](/script/2018-01-31-making-many-added-variable-plots-with-purrr-and-ggplot2.R).
```{r allcode, ref.label = labs, eval = FALSE, purl = FALSE}
```